diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..711b27da13b54e30f8b25e839ffc4f51ed80dc5c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py @@ -0,0 +1,997 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running CONVs + + The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run + CONV2D operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS CONVs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.Conv(A, B, C, D) + plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass_cppgen.op.Conv2d(kind="fprop", + # element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32) + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32) + + A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda') + B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda') + C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda') + D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda') + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + A = torch.rand((32, 128), dtype=torch.float32, device='cuda') + B = torch.rand((128, 256), dtype=torch.float32, device='cuda') + C = torch.zeros((32, 256), dtype=torch.float32, device='cuda') + D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda') + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + + # Do other work... + + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + # Do other work... + + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + plan.activation = cutlass_cppgen.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32) + args = plan.run() + + # Do other work... + + args.sync() +""" + +from __future__ import annotations +from typing import Optional +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +from cutlass_library import ( + ConvKind, + ConvMode, + DataTypeSize, + IteratorAlgorithm, + OperationKind, + SplitKMode, + StrideSupport, +) + +import cutlass_cppgen +from cutlass_cppgen import epilogue +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation +from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord +from cutlass_cppgen.utils import check, datatypes + + +class Conv2d(OperationBase): + """ + Constructs a ``Conv2d`` object. + + The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C, + along with the data type of output D and that used for accumulation, are bound to the ``Conv`` + object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation in fprop + + # Use the generic ``element`` parameter to concisely set all data types for operands to the same values. + Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32) + + # Explicitly specify the data types to use for A, B, C, and D. + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, + element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type as those passed in here). + # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout + Conv2d(kind="fprop", A=A, B=B, C=C, D=D) + + # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit + # those passed in via the generic ``element`` + Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + element=cutlass_cppgen.DataType.f32) + + The order of precedence for the setting of the data type for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor + 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``) + + :param kind: the convolution kind (i.e. fprop, wgrad, and dgrad) + :type kind: str + :param A: tensor representing data type of operand A + :param B: tensor representing data type of operand B + :param C: tensor representing data type of operand C + :param D: tensor representing data type of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + """ + def __init__( + self, kind="fprop", + A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, + element=None, + element_A=None, element_B=None, element_C=None, element_D=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) + # Verify the kernel cc + if self.current_cc in [90, 100, 101, 103]: + # The Conv2d kernel on Hopper (SM90) is currently unsupported + # Revert to use SM80-tagged kernels + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self.specified_kernel_cc = 80 + self._reset_options(80) + + # The arch is used in testing + self.arch = self.current_cc + self.name = "conv2d" + kind + + # The convolution kind. (concept: cutlass_library.library.ConvKind) + self.conv_kind = datatypes.getattr_enum(ConvKind, kind) + + # The element types (concept: cutlass library types) of A, B, C, and D + elements = [] + layouts = [] + + # Complete the data types based on user-provided arguments + for elt, tens, name in zip([element_A, element_B, element_C, element_D], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + + elt_to_set = None + lay_to_set = None + + if tens is not None: + elt_to_set, _ = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + + assert elt_to_set is not None + + # Currently we only support layout TensorNHWC + lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + # Default inputs if none is supplied in run() + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + # We only specify the stride of the swizzling functor here + # The actual swizzling functor is determined in run based on conv_kind and stride + self._swizzling_stride = 1 + + # Arguments that will be set to default value in _reset_operations + # The default tile_description and op_class are fetched from manifest of cutlass library + self._tile_description = None + self.op_class = None + # The default identity epilogue will be created + self.epilogue_functor = None + + self._reset_operations() + + # Arguments that will be determined online based on arguments of "run" + # based on stride, input/output channels, alignment, and conv_kind + self._iterator_algorithm = None + self._stride_support = None + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b, self._math_operation + ) + + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt + else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}{math_op_str}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(epilogue.identity) + + self.alignment_pref_A = min( + 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + self.alignment_pref_B = min( + 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) + self.alignment_pref_C = min( + 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C"))) + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + op = self.possible_operations.default_operation(self._math_operation) + self._tile_description = datatypes.td_from_profiler_op(op) + if "cluster_shape" in td.keys(): + if td["cluster_shape"] != [1, 1, 1]: + cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") + td["cluster_shape"] = [1, 1, 1] + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass_cppgen.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + valid, msg = check.valid_stage_count(self.cc, self.current_cc, td) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + descriptions = [] + description_str = [] + for op in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(op) + + if self._math_operation is not None: + if td.math_instruction.math_operation != self._math_operation: + continue + + if str(td) not in description_str: + description_str.append(str(td)) + descriptions.append(td) + return descriptions + + # + # Swizzling functor Related + # + + @property + def swizzling_stride(self): + """ + Returns the stride of swizzling currently being used by the Conv2d + + :return: swizzing stride + """ + return self._swizzling_stride + + @swizzling_stride.setter + def swizzling_stride(self, stride: int): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if not isinstance(stride, int): + raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}") + self._swizzling_stride = stride + + def _propose_swizzling_functor(self, stride): + """ + Automatically propose the swizzling functor based on the stride + """ + if self.conv_kind == ConvKind.Dgrad: + if stride[0] != 1 or stride[1] != 1: + return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") + + return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}") + + # + # Iterator Algorithm Related + # + + @property + def iterator_algorithm(self) -> IteratorAlgorithm: + """ + Returns the iterator algorithm + """ + return self._iterator_algorithm + + @iterator_algorithm.setter + def iterator_algorithm(self, alg: str): + """ + Sets the iterator algorithm + + :param alg: The iterator algorithm + :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels" + """ + iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg) + + # Check if the iterator algorithm is valid + if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop: + raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.") + + self._iterator_algorithm = iterator_alg + + def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm: + """ + Propose a valid iterator algorithm based on problem size and alignment + """ + if self.conv_kind == ConvKind.Fprop: + # Check whether the fixed channel is applicable + if problem_size.C == alignment_a: + return IteratorAlgorithm.FixedChannels + elif (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + elif self.conv_kind == ConvKind.Dgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + elif self.conv_kind == ConvKind.Wgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0): + return IteratorAlgorithm.Optimized + else: + return IteratorAlgorithm.Analytic + + def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool: + """ + Validate whether the user provide iterator algorithm works for the given problem size + """ + if self.conv_kind == ConvKind.Fprop: + if iterator_algorithm == IteratorAlgorithm.FixedChannels: + return problem_size.C == alignment_a + elif iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32) + elif iterator_algorithm == IteratorAlgorithm.FewChannels: + return problem_size.C % alignment_a == 0 + elif self.conv_kind == ConvKind.Dgrad: + if iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0) + elif self.conv_kind == ConvKind.Wgrad: + if iterator_algorithm == IteratorAlgorithm.Optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0) + + return True + + # + # Stride Support Related + # + + def _propose_stride_support(self, stride): + if self.conv_kind == ConvKind.Dgrad: + if stride[0] == 1 and stride[1] == 1: + return StrideSupport.Unity + + return StrideSupport.Strided + + # + # Construct and Compilation + # + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation: + """ + Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current + kernel specification of the ``Conv2d`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass_library.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass_cppgen.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.Conv2dOperation + """ + # Get alignment + alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B) + alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C) + + tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + + if tile_description is None: + if self.tile_description is not None: + tile_description = self.tile_description + else: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + if iterator_algorithm is None: + # If the iterator algorithm is already set + if self.iterator_algorithm is not None: + iterator_algorithm = self.iterator_algorithm + else: + # Otherwise, we conservatively use the analytic iterator for correctness + iterator_algorithm = IteratorAlgorithm.Analytic + + if stride_support is None: + # If the stride support is already set + if self._stride_support is not None: + stride_support = self._stride_support + else: + # Otherwise, we assume strided + stride_support = StrideSupport.Strided + + if swizzling_functor is None: + # If the swizzling functor is already set + swizzling_functor = self._propose_swizzling_functor(stride=(2, 2)) + + if epilogue_functor is None: + if self.epilogue_functor is not None: + epilogue_functor = self.epilogue_functor + else: + epilogue_functor = self._create_epilogue_functor_activation(self._activation) + + # Reset the alignment of the epilogue functor + epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor) + + operation = Conv2dOperation( + conv_kind=self.conv_kind, + iterator_algorithm=iterator_algorithm, + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + stride_support=stride_support, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None, + epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + ::param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass_library.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass_cppgen.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was compiled + :rtype: cutlass_cppgen.backend.Conv2dOperation + """ + + self.operation = self.construct( + tile_description, alignment_A, alignment_B, alignment_C, + iterator_algorithm, stride_support, swizzling_functor, epilogue_functor) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + # + # Run Related + # + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, _ = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type: + raise Exception(f'Tensor {name} with type and layout {dtype} ' + f'does not match the expected type of {ref_type}.') + + def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation): + if self.conv_kind == ConvKind.Fprop: + input = A + weight = B + output = C + output_tensor = "C" + elif self.conv_kind == ConvKind.Dgrad: + output = A + weight = B + input = C + output_tensor = "A" + elif self.conv_kind == ConvKind.Wgrad: + output = A + input = B + weight = C + output_tensor = "A" + else: + raise Exception(f"Convolution kind {self.conv_kind} is not supported") + + N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV") + K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV") + _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV") + + problem_size = Conv2DProblemSize( + N_, H_, W_, C_, + K_, R_, S_, C_, + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + ConvMode.CrossCorrelation, + 1, 1 + ) + + if P_ != problem_size.P or Q_ != problem_size.Q: + raise Exception( + f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})") + + return problem_size + + def run(self, A=None, B=None, C=None, D=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), + alpha=None, beta=None, + split_k=("serial", 1), sync: bool = True, + print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in the call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1) + :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0) + :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1) + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param split_k: a tuple (split_k_mode, split_k_slices) + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.Conv2dArguments + """ + if not stream: + stream = cuda.CUstream(0) + super().run_setup() + + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + # handle the case when there is no C + if C is None: + if beta != 0: + raise Exception(f"With beta {beta} != 0, C has to be provided.") + else: + C = D + + # Construct problem size based on input + # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching + problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation) + + # Propose stride support based on input + stride_support = self._propose_stride_support(stride) + + # Propose swizzling functor + swizzling_functor = self._propose_swizzling_functor(stride) + + shape_a = datatypes.get_tensor_shape(A, op="CONV") + shape_b = datatypes.get_tensor_shape(B, op="CONV") + shape_c = datatypes.get_tensor_shape(C, op="CONV") + + # Get the alignment + alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B") + alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C") + + alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A) + alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B) + alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C) + + # Propose iterator algorithm based on input + if self._iterator_algorithm is None: + # Propose a default iterator algorithm based on the problem size + iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b) + else: + if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)): + iterator_algorithm = self._iterator_algorithm + else: + raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.") + + epilogue_args = [alpha, beta] + + if hasattr(self, "_activation_args"): + if isinstance(self._activation_args, list): + epilogue_args += self._activation_args + else: + epilogue_args.append(self._activation_args) + + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity) + else: + epilogue_functor = self.epilogue_functor + + # The alignment is determined by the iterator function (I believe) + self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support, + swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module) + + # Create reduction operation for parallel split-k + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor) + self.reduction_operation = ReductionOperation( + shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C, + element_accumulator=self._element_accumulator, + element_compute=self._element_accumulator, + epilogue_functor=epilogue_functor_reduction, + count=alignment_c + ) + if print_module: + print(self.reduction_operation.rt_module.emit()) + compiler.add_module([self.reduction_operation,]) + + arguments = Conv2dArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=self.operation.epilogue_type(*epilogue_args), + split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]), + split_k_slices=split_k[1], + stream=stream + ) + + self.operation.run(arguments) + + if split_k[0] == "parallel" and split_k[1] > 1: + implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind) + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[implicit_gemm_size.m, implicit_gemm_size.n], + partitions=split_k[1], + workspace=arguments.ptr_D, + destination=D, + source=C, + output_op=self.reduction_operation.epilogue_type(*epilogue_args), + stream=stream + ) + self.reduction_operation.run(reduction_arguments) + + if sync: + if split_k[0] == "parallel" and split_k[1] > 1: + reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() + else: + arguments.sync() + + return arguments + + # + # Helper functions + # + @staticmethod + def output_size(input_size, weight_size, padding, stride, dilation): + problem_size = Conv2DProblemSize( + *input_size, + *weight_size, + padding[0], padding[1], + stride[0], stride[1], + dilation[0], dilation[1], + ConvMode.CrossCorrelation, + 1, 1 + ) + return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K) + + +# +# Easy to use interfaces for fprop, wgrad, and dgrad +# + +class Conv2dFprop(Conv2d): + def __init__( + self, + input=None, weight=None, C=None, output=None, alpha=1, beta=0, + element=None, + element_input=None, element_weight=None, element_C=None, element_output=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = input, weight, output + element_A, element_B, element_D = element_input, element_weight, element_output + super().__init__( + "fprop", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run( + self, input=None, weight=None, C=None, output=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + + if not stream: + stream = cuda.CUstream(0) + + A, B, D = input, weight, output + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) + + +class Conv2dDgrad(Conv2d): + def __init__( + self, + grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, weight, grad_input + element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input + super().__init__( + "dgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + # + if not stream: + stream = cuda.CUstream(0) + + A, B, D = grad_output, weight, grad_input + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) + + +class Conv2dWgrad(Conv2d): + def __init__( + self, + grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, input, grad_weight + element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight + super().__init__( + "wgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> Conv2dArguments: + if not stream: + stream = cuda.CUstream(0) + + A, B, D = grad_output, input, grad_weight + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f9b1ab43a1c45d0024e99e50e45813ba18866e --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py @@ -0,0 +1,725 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run + GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.Gemm(A, B, C, D) + plan.run() + + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32, + # layout=cutlass_cppgen.LayoutType.RowMajor) + plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + + A0 = torch.rand((128, 256), device='cuda') + B0 = torch.rand((256, 64), device='cuda') + C0 = torch.zeros((128, 64), device='cuda') + D0 = torch.zeros((128, 64), device.'cuda') + plan.run(A0, B0, C0, D0) + + A = torch.rand((32, 128), device='cuda') + B = torch.rand((128, 256), device='cuda') + C = torch.zeros((32, 256), device='cuda') + D = torch.zeros((32, 256), device.'cuda') + plan.run(A1, B1, C1, D1) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.compile() + + # Do other work... + + plan.run(A0, B0, C0, D0) + + # Do other work... + + plan.run(A1, B1, C1, D1) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.activation = cutlass_cppgen.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor) + args = plan.run() + + # Do other work... + + args.sync() +""" +from __future__ import annotations +from typing import Optional +from math import prod + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_library import ( + DataType, + DataTypeSize, + GemmUniversalMode, + KernelScheduleSuffixes, +) + +import cutlass_cppgen +from cutlass_cppgen import epilogue, swizzle +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.library import TensorDescription, TileDescription +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes + + +class Gemm(OperationBase): + """ + Constructs a ``Gemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime -- + these are not to be changed after a ``Gemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation. All operands are row major. + + # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts + # for operands to the same values. + Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``. + Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32, + element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type and layout as those passed in here). + # A, B, C, and D are row-major torch.Tensor objects of type torch.float32 + Gemm(A=A, B=B, C=C, D=D) + + # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is + # the same as that for D, at present) + Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor, + layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor) + + # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types + # and layouts will inherit those passed in via the generic ``element`` and ``layout`` + Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor, + element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor) + + The order of precedence for the setting of the data type and layout for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor + 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``, ``layout``) + + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass_cppgen.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :param layout_A: layout of operand A + :type layout_A: cutlass_cppgen.LayoutType + :param layout_B: layout of operand B + :type layout_B: cutlass_cppgen.LayoutType + :param layout_C: layout of operand C + :type layout_C: cutlass_cppgen.LayoutType + :param layout_D: layout of operand D + :type layout_D: cutlass_cppgen.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc) + self.name = "gemm" + self.compiled = False + + elements = [] + layouts = [] + + # Check that at least one of the following is set for each tensor (illustrated assuming tensor A): + # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout`` + for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D], + [layout_A, layout_B, layout_C, layout_C], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if lay is not None and tens is not None: + raise Exception(f'Must not specify both layout_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + if lay is None and tens is None and layout is None: + raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.') + + elt_to_set = None + lay_to_set = None + if tens is not None: + elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + lay_to_set = lay if lay is not None else layout + + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + self.epilogue_functor = None + self.op_class = None + self._tile_description = None + + self._reset_operations() + + self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1 + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b, self._math_operation) + + if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.TensorOp + elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass_cppgen.OpcodeClass.Simt + else: + if self._math_operation is not None: + math_op_str = f' and math operation {self._math_operation}' + else: + math_op_str = '' + + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}{math_op_str}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity) + + @property + def swizzling_functor(self): + """ + Returns the type of the swizzling functor currently being used by the GEMM + + :return: swizzing functor type + """ + return self._swizzling_functor + + @swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK: + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: + raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') + + if self.current_cc in [90, 100, 101, 103]: + raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+') + self._swizzling_functor = swizzling_functor + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + op = self.possible_operations.default_operation(self._math_operation) + self._tile_description = datatypes.td_from_profiler_op(op) + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass_cppgen.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) + + if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0: + valid = False + msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103" + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations] + if self._math_operation is not None: + tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation] + return tds + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal: + """ + Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current + kernel specification of the ``Gemm`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.GemmOperationUniversal + """ + alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) + alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) + + tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + + if alignment_C is None: + alignment_C = max(self.possible_operations.alignments("C")) + if self._element_c != DataType.void: + alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C) + + if tile_description is None: + if self._tile_description is None: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + + # The selected op may have lower alignment than that determined above, so we must + # reset alignment here. + alignment_C = op.C.alignment + else: + tile_description = self._tile_description + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self._tile_description = tile_description + + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + + operation = GemmOperationUniversal( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: operation that was compiled + :rtype: cutlass_cppgen.backend.GemmOperationUniversal + """ + self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + def _verify_rank(self, tensor): + """ + Verifies that ``tensor`` has rank greater than 1 + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + """ + if len(tensor.shape) < 2: + raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}") + + def _get_batch_count(self, A, B, C, D) -> int: + """ + Returns the batch count specified by the tensors A, B, C, and D and verifies that these + tensors match in batch size. Presence of a batch dimension is detected by one of the + tensors being rank 3. If a batch dimension is present, it must be present in one of + operands A, B, or C (but need not be in all), and must be present in D. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple of batch count dimensions + :rtype: tuple + """ + A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1 + B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1 + + if 1 not in [A_batch, B_batch]: + if A_batch != B_batch: + raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}") + return max(A_batch, B_batch) + + def _get_batch_stride(self, tensor) -> int: + """ + Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0. + + :param tensor: tensor object to process + :type tensor: numpy/cupy/torch array/tensor object + + :return: stride between each matrix in the batch + :rtype: int + """ + if tensor is not None and len(tensor.shape) > 2: + return tensor.shape[-2] * tensor.shape[-1] + else: + return 0 + + def _get_problem_args(self, A, B, C, D) -> tuple: + """ + Returns the problem size and GEMM universal mode to use for the + given operands. + + :param A: tensor A + :type A: numpy/cupy/torch array/tensor object + :param B: tensor B + :type B: numpy/cupy/torch array/tensor object + :param C: tensor C + :type C: numpy/cupy/torch array/tensor object + :param D: tensor D + :type D: numpy/cupy/torch array/tensor object + + :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int) + :rtype: tuple + """ + M, K = A.shape[-2:] + N = B.shape[-1] + mode = GemmUniversalMode.Gemm + + batch_count = self._get_batch_count(A, B, C, D) + returned_batch_count = batch_count + + # If we are running a batched GEMM in which there is a nonzero batch stride + # only for A, then we can fold the batched dimension of A into the M dimension + # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A + # and C are row major. A similar operation can be performed if only B has a nonzero + # batch dimension + if batch_count > 1: + A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor + B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor + C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor + + # Consider a Tensor to be batched if its rank is > 2 and + # the product of the modes beyond rank 2 equals our pre-determined batch size. + batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count) + + if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row: + M *= batch_count + returned_batch_count = 1 + elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row: + N *= batch_count + returned_batch_count = 1 + else: + mode = GemmUniversalMode.Batched + + return GemmCoord(M, N, K), mode, returned_batch_count + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, layout = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type or layout != ref_layout: + try: + # Attempt to transpose the tensor to fit the desired layout + tensor = tensor.transpose(-1, -2) + except: + raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) ' + f'does not match the expected type and ' + f'layout of ({ref_type}, {ref_layout}) and transpose failed.') + + def run(self, A=None, B=None, C=None, D=None, + alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None, + stream: Optional[cuda.CUstream] = None) -> GemmArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in this call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.GemmArguments + """ + if not stream: + stream = cuda.CUstream(0) + super().run_setup() + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + is_void_c = self._element_c == DataType.void + + self._verify_rank(A) + self._verify_rank(B) + if not is_void_c: + self._verify_rank(C) + self._verify_rank(D) + + alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") + + # Set C alignment based on D.shape so as to correctly get an alignment with void-C + # kernels, for which `C` is None. + alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C") + self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + problem_size, mode, batch_count = self._get_problem_args(A, B, C, D) + + if mode == GemmUniversalMode.Gemm or batch_count == 1: + kwargs = {'split_k_slices': 1} + else: + kwargs = { + 'batch': batch_count, + 'batch_strides': { + 'A': self._get_batch_stride(A), + 'B': self._get_batch_stride(B), + 'C': self._get_batch_stride(C), + 'D': self._get_batch_stride(D) + } + } + + kwargs['stream'] = stream + + if isinstance(self.epilogue_functor, EpilogueFunctorVisitor): + output_op = self.operation.epilogue_type(visitor_args) + else: + output_op = self.operation.epilogue_type(alpha, beta) + + arguments = GemmArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=output_op, + gemm_mode=mode, + **kwargs + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..59f90535c29a816541bc1a2155fea35afd1c94fd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py @@ -0,0 +1,269 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running GEMMs. + + The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run + grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS grouped GEMMs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects + plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) +""" +from __future__ import annotations +from typing import Optional +from cutlass_library import DataTypeSize + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +from cutlass_cppgen.backend.gemm_operation import ( + GemmGroupedArguments, + GemmOperationGrouped, +) +from cutlass_cppgen.backend.library import ( + SchedulerMode, + TensorDescription, + TileDescription, +) +from cutlass_cppgen.op.gemm import Gemm +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils import check, datatypes + + +class GroupedGemm(Gemm): + """ + Constructs a ``GroupedGemm`` object. + + The data types and layouts of operands A, B, and C, along with the data type of output D + and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime -- + these are not to be changed after a ``GroupedGemm`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. Please see the constructor + for ``Gemm`` for examples of these. + + :param cc: compute capability of device to generate kernels for + :type cc: int + :param A: tensor representing data type and layout of operands A + :param B: tensor representing data type and layout of operands B + :param C: tensor representing data type and layout of operands C + :param D: tensor representing data type and layout of operands D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass_cppgen.DataType + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass_cppgen.DataType + :param layout: generic layout type to be used for operands A, B, C, and D + :type layout: cutlass_cppgen.LayoutType + :param element_A: data type to be used for operand A + :type element_A: cutlass_cppgen.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass_cppgen.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass_cppgen.DataType + :type layout_A: layout of operand A + :param layout_A: cutlass_cppgen.LayoutType + :type layout_B: layout of operand B + :param layout_B: cutlass_cppgen.LayoutType + :type layout_C: layout of operand C + :param layout_C: cutlass_cppgen.LayoutType + :type layout_D: layout of operand D + :param layout_D: cutlass_cppgen.LayoutType + """ + + def __init__( + self, A=None, B=None, C=None, D=None, + alpha=1.0, beta=0.0, element_accumulator=None, + element=None, layout=None, + element_A=None, element_B=None, element_C=None, element_D=None, + layout_A=None, layout_B=None, layout_C=None, + cc: int = None, + ): + super().__init__( + A=A, B=B, C=C, D=D, + alpha=alpha, beta=beta, + element_accumulator=element_accumulator, + element=element, layout=layout, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + cc=cc + ) + + # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80 + if self.current_cc in [90, 100, 101, 103]: + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + + self.name = "grouped_gemm" + + @Gemm.swizzling_functor.setter + def swizzling_functor(self, swizzling_functor): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + raise Exception('Grouped GEMM does not currently support different swizzling functors') + + def construct(self, tile_description: TileDescription = None, + alignment_A: int = None, + alignment_B: int = None, + alignment_C: int = None) -> GemmOperationGrouped: + """ + Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current + kernel specification of the ``Gemm`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass_cppgen.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + + :return: operation that was constructed + :rtype: cutlass_cppgen.backend.GemmOperationGrouped + """ + alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A"))) + alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B"))) + alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C"))) + + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) + + tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A) + tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + + if tile_description is None: + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + operation = GemmOperationGrouped( + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + epilogue_functor=self.epilogue_functor, + swizzling_functor=self._swizzling_functor, + precompute_mode=SchedulerMode.Device) + + return operation + + def run(self, A, B, C, D, + alpha=None, beta=None, sync: bool = True, + print_module: bool = False, + stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments: + """ + Runs the kernel currently specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: list of tensors representing data type and layout of operand A + :type A: list + :param B: list of tensors representing data type and layout of operand B + :type B: list + :param C: list of tensors representing data type and layout of operand C + :type C: list + :param D: list of tensors representing data type and layout of operand D + :type D: list + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + :param stream: cuda stream, defaults to cuda.cuda.CUstream(0) + :type stream: :class:`cuda.cuda.CUstream` + + :return: arguments passed in to the kernel + :rtype: cutlass_cppgen.backend.GemmGroupedArguments + """ + if not stream: + stream = cuda.CUstream(0) + + super().run_setup() + + if len(A) != len(B) or len(A) != len(C) or len(A) != len(D): + raise Exception("Lengths of A, B, C, and D lists must be equal") + + problem_sizes = [] + As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4)) + for i in range(len(A)): + As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A") + Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B") + Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C") + Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D") + problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1])) + + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As)) + alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs)) + alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs)) + self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, print_module=print_module) + + arguments = GemmGroupedArguments( + operation=self.operation, + problem_sizes=problem_sizes, + A=As, B=Bs, C=Cs, D=Ds, + output_op=self.operation.epilogue_type(alpha, beta), + stream=stream + ) + + self.operation.run(arguments) + + if sync: + arguments.sync() + + return arguments diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py new file mode 100644 index 0000000000000000000000000000000000000000..bebf07a7e5b83a1cf14cfecf19e90f730e305dce --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py @@ -0,0 +1,431 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) +""" + +from bisect import bisect_left + +from cutlass_library import ( + DataType, + DataTypeSize, + MathOperation, + OperationKind, + SharedMemPerCC +) + +import cutlass_cppgen +from cutlass_cppgen import get_option_registry +from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor +from cutlass_cppgen.backend.evt.passes.util import cc_map +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity +from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs +from cutlass_cppgen.swizzle import get_swizzling_functors +from cutlass_cppgen.utils import datatypes, check + + +class OperationBase: + """ + Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) + """ + + def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm): + """ + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv) + :type operation_kind: cutlass_library.OperationKind + """ + self.operation_kind = operation_kind + self.cc = cc if cc is not None else device_cc() + self.specified_kernel_cc = kernel_cc is not None + self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) + self.tile_description = None + self._math_operation = None + + self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind) + + if self.options is None: + raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") + + # Default activation function: identity + self._activation = identity + + def _find_closest_cc(self, cc: int) -> int: + """ + Returns the closest CC in _generator_ccs less than or equal to `cc` + + :param cc: compute capability to query + :type cc: int + + :returns: closest CC in _generator_ccs less than or equal to `cc` + :rtype: int + """ + if cc in _generator_ccs: + return cc + + # Find closest CC lower than this CC + idx = bisect_left(_generator_ccs, cc) + if idx == 0: + raise Exception(f'No valid CC to fall back to for {cc}') + return _generator_ccs[idx-1] + + def activations(self) -> list: + """ + Returns possible activation functions that can be used + + :return: list of activation functions that can be used + :rtype: list + """ + return get_activations() + + def swizzling_functors(self) -> list: + """ + Returns possible swizzling functions that can be used + + :return: list of swizzling functions that can be used + :rtype: list + """ + return get_swizzling_functors() + + def _reset_options(self, cc: int): + """ + Resets the kernel options based on cc + + :param cc: compute capability to reset to + :type cc: int + """ + if cc != self.current_cc: + if cc not in _generator_ccs: + raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.') + self.current_cc = cc + self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind) + + def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): + """ + Verifies the following properties: + 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) + 2) If ``scalar`` is not ``None``, its datatype must match matches the current version + set by the plan (i.e., those in ``ref_dtype``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. + + :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type scalar: numpy/cupy/torch scalar + :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_scalar: numpy/cupy/torch scalar + :param ref_dtype: data type for the scalar that this object was initialized to + :param name: identifier of the scalar to verify. Used in raising exceptions + :type name: str + + :return: valid scalar to use + :rtype: numpy/cupy/torch scalar + """ + if scalar is None: + if ref_scalar is None: + raise Exception(f"Scalar {name} must be set.") + return ref_scalar + if hasattr(scalar, "dtype"): + dtype = datatypes.library_type(scalar.dtype) + if dtype != ref_dtype: + raise Exception( + f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." + ) + return scalar + + def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): + """ + Verifies the following properties: + If ref_dtype is not void: + 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) + 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions + set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + If ref_dtype is void: + Neither ``tensor`` nor ``ref_tensor`` are set + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + + :return: valid tensor object to use + :rtype: numpy/cupy/torch array/tensor object + """ + if ref_dtype == DataType.void: + if tensor is not None or ref_tensor is not None: + raise Exception("Operands with element DataType.void must not be provided a tensor") + return None + + if tensor is None: + if ref_tensor is None: + raise Exception(f"Tensor {name} must be set.") + return ref_tensor + + self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) + return tensor + + @property + def opclass(self) -> cutlass_cppgen.OpcodeClass: + """ + Returns the opcode class currently in use + + :return: opcode class currently in use + :rtype: cutlass_cppgen.OpcodeClass + """ + return self.op_class + + @opclass.setter + def opclass(self, oc: cutlass_cppgen.OpcodeClass): + if isinstance(oc, str): + oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc) + if oc in self.possible_op_classes: + self.op_class = oc + else: + raise Exception( + f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' + f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' + f'layout combination ({self._layout_a}, {self._layout_b}).') + + # Changing the op class also changes the possible operations available. Reset these. + self.possible_operations = self.options.operations( + self.op_class, self._element_a, self._element_b, + self._element_accumulator, self._layout_a, self._layout_b, self._math_operation) + + # Changing the op class changes the elements per access in the epilogue. Reset this. + if self.epilogue_functor is not None: + self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) + + @property + def math_operation(self) -> cutlass_cppgen.MathOperation: + """ + Returns the math operation currently in use + + :return: math operation currently in use + :rtype: cutlass_cppgen.MathOperation + """ + return self._math_operation + + @math_operation.setter + def math_operation(self, mo: cutlass_cppgen.MathOperation): + if isinstance(mo, str): + mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo) + + if not self.specified_kernel_cc: + if self.current_cc in [90, 100, 101, 103]: + # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif self.current_cc in [90, 100, 101, 103]: + raise Exception("CUTLASS 3.0 kernels do not use different math operations. " + "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`" + "parameter when constructing the plan.") + + self._math_operation = mo + self._reset_operations() + + def _elements_per_access(self): + if self.op_class == cutlass_cppgen.OpcodeClass.Simt: + return 1 + elif self._element_c != DataType.void: + return 128 // DataTypeSize[self._element_c] + else: + return 128 // max(self.possible_operations.alignments("C")) + + def _create_epilogue_functor_activation(self, activation): + """ + Returns the epilogue functor with given activation function + """ + if self.epilogue_functor is None: + elements_per_access = self._elements_per_access() + else: + elements_per_access = self.epilogue_functor.epilogue_vector_length + + if not self.specified_kernel_cc: + if self.current_cc in [90, 100, 101, 103] and activation != identity: + # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + if self._element_c != self._element_d: + raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None): + # SM80 fallback kernels are currently used. Since an identity activation is requested, + # we can switch back to using SM90 kernels. + self._reset_options(self.cc) + self._reset_operations(reset_epilogue=False) + else: + if self.current_cc in [90, 100, 101, 103] and activation != identity: + raise Exception("Epilogues with elementwise fusion are not currently supported " + "in the Python interface for 3.x kernels. To use 2.x kernels " + "with fused elementwise epilogues, do not set the `kernel_cc` " + "parameter when constructing the plan.") + + return get_activation_epilogue( + activation, + self._element_d, + elements_per_access, + self._element_accumulator, + self._element_accumulator, + ) + + def _reset_epilogue_functor_activation(self, activation): + """ + Set the epilogue functor based on the provided activation function + """ + self.epilogue_functor = self._create_epilogue_functor_activation(activation) + + def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor): + """ + Reset the alignment of the current epilogue functor based on alignment C + """ + if isinstance(epilogue_functor, EpilogueFunctorVisitor): + return epilogue_functor + + if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'): + # Identity epilogue does not have 'activation_functor' + activation = identity + else: + activation = epilogue_functor.activation_functor + + epilogue_functor = get_activation_epilogue( + activation, + self._element_d, + alignment, + self._element_accumulator, + self._element_accumulator, + ) + return epilogue_functor + + @property + def activation(self): + """ + Returns the type of the current activation function used + """ + if hasattr(self.epilogue_functor, "activation_functor"): + return self.epilogue_functor.activation_functor + else: + return identity + + @activation.setter + def activation(self, act): + """ + Sets the type of the activation function to use + Activation can come with a set of arguments + + :param act: type of activation function to use + :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01) + + """ + if isinstance(act, tuple): + if isinstance(act[0], str): + act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0]) + else: + act_fn = act[0] + self._reset_epilogue_functor_activation(act_fn) + self._activation_args = act[1] + self._activation = act[0] + else: + if isinstance(act, str): + act = getattr(cutlass_cppgen.backend.epilogue, act) + self._reset_epilogue_functor_activation(act) + self._activation = act + + @property + def epilogue_visitor(self): + """ + Return the epilogue functor + """ + return self.epilogue_functor + + @epilogue_visitor.setter + def epilogue_visitor(self, visitor): + """ + Create the epilogue visitor + """ + self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor) + + # The epilogue_functor may consume too much shared memory + # Reset the possible operations + if self.cc not in [90, 100, 101, 103]: + # The shared memory is only a concern for sm90+ epilogue + # In sm80, the epilogue and mainloop share the shared memory + return + + datatype_comb = self.possible_operations.datatype_comb + layout_comb = self.possible_operations.layout_comb + new_possible_operations = KernelsForDataType(datatype_comb, layout_comb) + for operation in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(operation) + # Filter invalid epilogue schedules + if cc_map[self.cc] == 90 and td.epilogue_schedule not in [ + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized, + cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]: + continue + epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td) + + # Verify the maximum number of mainloop stages + mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_capacity_bytes = SharedMemPerCC[self.cc] << 10 + mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage + if mainloop_stages < 2: + # Mainloop stages must >= 2 + continue + + new_possible_operations.add(operation) + if len(new_possible_operations.all_operations) == 0: + raise RuntimeError( + "The epilogue consumes too much shared memory. " + "No valid tile description is found in the generator.") + self.possible_operations = new_possible_operations + + + def run_setup(self): + """ + Steps that must be taken before caling `plan.run()` + """ + # Initialize the memory pool if, if not already done + cutlass_cppgen.get_memory_pool() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py new file mode 100644 index 0000000000000000000000000000000000000000..a718f9bb4432f1f51457661abe27e24ea818aba4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py @@ -0,0 +1,184 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for expressing shapes +""" + +from cutlass_library import ( + ConvMode, + ConvKind, + LayoutType +) +from cutlass_cppgen.backend.c_types import ( + Conv2DProblemSize_, + GemmCoord_, + GemmCoordBatched_ +) + + +class MatrixCoord: + def __init__(self, row, col): + self._row = row + self._col = col + + @property + def row(self): + return self._row + + @property + def column(self): + return self._col + + def leading_dimension(self, layout: LayoutType) -> int: + """ + Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord. + + :param layout: layout of matrix + :type layout: cutlass_library.LayoutType + + :returns: leading dimension + :rtype: int + """ + if layout == LayoutType.RowMajor: + return self._col + elif layout == LayoutType.ColumnMajor: + return self._row + else: + raise Exception(f'Unsupported layout for leading dimension calculation: {layout}') + + +class GemmCoord: + def __init__(self, m: int, n: int, k: int): + self._m = m + self._n = n + self._k = k + + @property + def m(self) -> int: + return self._m + + @property + def n(self) -> int: + return self._n + + @property + def k(self) -> int: + return self._k + + @property + def mk(self) -> MatrixCoord: + return MatrixCoord(self._m, self._k) + + @property + def mn(self) -> MatrixCoord: + return MatrixCoord(self._m, self._n) + + @property + def kn(self) -> MatrixCoord: + return MatrixCoord(self._k, self._n) + + @property + def ctype(self) -> GemmCoord_: + return GemmCoord_(self._m, self._n, self._k) + + def batched_ctype(self, batch_count: int) -> GemmCoordBatched_: + return GemmCoordBatched_(self._m, self._n, self._k, batch_count) + + +class Conv2DProblemSize: + def __init__( + self, n: int, h: int, w: int, c: int, + k: int, r: int, s: int, c_: int, + pad_h: int, pad_w: int, stride_h: int, stride_w: int, + dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation, + split_k_slices: int=1, groups: int=1): + + self.N = n + self.H = h + self.W = w + self.C = c + self.K = k + self.R = r + self.S = s + self.pad_h = pad_h + self.pad_w = pad_w + self.stride_h = stride_h + self.stride_w = stride_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.mode = int(mode) + self.split_k_slices = split_k_slices + self.groups = groups + self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1 + self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1 + + @property + def ctype(self) -> Conv2DProblemSize_: + return Conv2DProblemSize_(self) + + def implicit_gemm_size(self, kind: ConvKind): + if kind == ConvKind.Fprop: + return GemmCoord( + self.N * self.P * self.Q, + self.K, + self.R * self.S * self.C // self.groups + ) + elif kind == ConvKind.Dgrad: + return GemmCoord( + self.N * self.H * self.W, + self.C, + self.R * self.S * self.K + ) + elif kind == ConvKind.Wgrad: + return GemmCoord( + self.K, + self.R * self.S * self.C, + self.N * self.P * self.Q + ) + + @staticmethod + def from_sizes(input_size, weight_size): + K, R, S, _ = weight_size + pad_h = R // 2 + pad_w = S // 2 + stride_h = 1 + stride_w = 1 + dilation_h = 1 + dilation_w = 1 + return Conv2DProblemSize( + *input_size, + *weight_size, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..ffd9483415ea36716bf4643d27b8d92f3e9878a5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py @@ -0,0 +1,65 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Registry of swizzling functions +""" + +from cutlass_library import SwizzlingFunctor + + +IdentitySwizzle1 = SwizzlingFunctor.Identity1 +IdentitySwizzle2 = SwizzlingFunctor.Identity2 +IdentitySwizzle4 = SwizzlingFunctor.Identity4 +IdentitySwizzle8 = SwizzlingFunctor.Identity8 +HorizontalSwizzle = SwizzlingFunctor.Horizontal +ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK +StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1 +StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4 +StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal + + +_swizzling_functors = [ + IdentitySwizzle1, + IdentitySwizzle2, + IdentitySwizzle4, + IdentitySwizzle8, + HorizontalSwizzle, + ThreadblockSwizzleStreamK, + StridedDgradIdentitySwizzle1, + StridedDgradIdentitySwizzle4, + StridedDgradHorizontalSwizzle, +] + + +def get_swizzling_functors(): + return _swizzling_functors diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..75d8416a15070ddcf2c6270248ccd9deff8e2137 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py @@ -0,0 +1,41 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_cppgen.utils.check import ( + alignment_or_default, + calculate_smem_usage, + calculate_smem_usage_per_stage, + valid_cluster_shape, + valid_schedule, + valid_stage_count, + update_alignment, +) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py new file mode 100644 index 0000000000000000000000000000000000000000..108f268b4bc54ec0839afb5c1602ba63e5b98743 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py @@ -0,0 +1,262 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for checking constraints on kernels and calculating kernel attributes +""" + +import ctypes + +from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC + +import cutlass_cppgen +from cutlass_cppgen.backend.library import TileDescription + + +def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int: + """ + Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + + :param td: tile description to compute shared memory of + :type td: TileDescription + :param operation_kind: identifier for the type of operation being performed + :type operation_kind: cutlass_library.OperationKind + + :return: number of bytes of shared memory consumed by a single stage + :rtype: int + """ + m, n, k = td.blackwell_threadblock_shape + if td.is_2sm: + m //= 2 + + if operation_kind == OperationKind.Gemm: + stage_barrier_bytes = 32 + return ( + (DataTypeSize[td.math_instruction.element_a] * m * k // 8) + + (DataTypeSize[td.math_instruction.element_b] * k * n // 8) + + stage_barrier_bytes + ) + else: + raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}") + + +def calculate_smem_usage(operation) -> int: + """ + Returns the amount of shared memory in bytes consumed by a kernel. + + :return: number of bytes of shared memory consumed by the operation + :return: int + """ + _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind) + return _per_stage * operation.tile_description.stages + + +def valid_stage_count( + cc: int, + kernel_cc: int, + td: TileDescription, + element_C: cutlass_cppgen.DataType = None, + element_D: cutlass_cppgen.DataType = None, + verbose: bool = True) -> tuple: + """ + Checks whether a device with `cc` supports the number of stages within `tile_description`, both + based on raw limits on the number of stages and based on shared memory capacity + + :param cc: compute capability of device in question + :type cc: int + :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS) + :type kernel_cc: int + :param td: tile description to check + :type td: TileDescription + :param element_C: data type of operand C + :type element_C: cutlass_cppgen.DataType + :param element_D: data type of operand D + :type element_D: cutlass_cppgen.DataType + :param verbose: whether to log warnings + :type verbose: bool + + :return: tuple with the first element indicating whether the provided tile description is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + if kernel_cc in [90, 100, 101, 103]: + if (td.stages is None or td.stages == 0): + # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically + # determines the stage count to use. Thus, all settings are valid in these scenarios. + return (True, "") + elif verbose: + cutlass_cppgen.logger.warning( + "Setting an explicit stage count for SM90 kernels currently may " + "result in compilation errors if the combination of tile shape, " + "stage count, and shared memory requirement of the epilogue exceeds " + "the available shared memory per SM.") + + if td.stages <= 0: + return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") + + if cc < 80 and td.stages != 2: + return (False, f"Tile description has stage count of {td.stages}, " + f"but only 2 stages are supported on SM{cc}.") + + # The calculation below does not consider shared memory used by the epilogue and, thus, + # only catches cases in which the mainloop exceeds the device's shared memory capacity. + # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the + # mainloop and epilogue is shared. + smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_usage_mainloop = (smem_per_stage * td.stages) + smem_arch = SharedMemPerCC[cc] << 10 + if smem_usage_mainloop > smem_arch: + return ( False, + "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n" + f"Details:\n" + f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and " + f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n" + f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.") + + return (True, "") + + +def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: + """ + Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`. + + :param cc: compute capability of device in question + :type cc: int + :param cluster_shape: dimensions of thread block cluster shape to check + :type cluster_shape: list + + :return: tuple with the first element indicating whether the provided cluster shape is + valid for the provided device and the second element being an error message + :rtype: tuple + """ + + if cc < 90 or cc in [120, 121]: + if cluster_shape != [1, 1, 1]: + return (False, + f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of " + f"{cluster_shape} for SM{cc}.") + else: + return (True, "") + + if len(cluster_shape) != 3: + return (False, + f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}") + + if cluster_shape[2] != 1: + return (False, + "CUTLASS kernels currently require the third dimension of cluster shape to be 1. " + f"Received cluster shape of {cluster_shape}.") + + return (True, "") + + +def valid_schedule( + cc: int, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple: + """ + Checks that the kernel and epilogue schedules passed in are a valid combination for + a device of compute capability ``cc``. + + :param cc: compute capability of device in question + :type cc: int + :param kernel_schedule: kernel schedule type + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue schedule type + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param tile_scheduler: tile scheduler type + :type tile_scheduler: cutlass_cppgen.TileSchedulerType + + :return: tuple with the first element indicating whether the provided schedules are + valid for the provided device and the second element being an error message + :rtype: tuple + """ + kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto) + epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto) + tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default) + if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default): + return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)") + + if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)): + return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") + + if not tile_scheduler_default: + cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative] + if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): + return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") + return (True, "") + + +def alignment_or_default(alignment_provided: int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment + + +def update_alignment(alignment_provided:int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + if alignment_provided % default_alignment == 0: + return default_alignment + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py new file mode 100644 index 0000000000000000000000000000000000000000..c03a834dc47871bebe618752e4775a0a7434ff78 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py @@ -0,0 +1,362 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for converting between frontend datatypes and CUTLASS datatypes +""" + +import cutlass_cppgen +from cutlass_library import ( + DataTypeSize, + MathOperation, + MathInstruction +) +from cutlass_cppgen.backend.library import ( + TileDescription, +) + +bfloat16_available = None +cupy_available = None +numpy_available = None +torch_available = None +_library_to_cupy_dict = None +_library_to_numpy_dict = None +_library_to_torch_dict = None +_torch_to_library_dict = None + + +def is_numpy_available(): + global numpy_available, _library_to_numpy_dict + if numpy_available is None: + try: + import numpy as np + + numpy_available = True + _library_to_numpy_dict = { + cutlass_cppgen.DataType.f16: np.float16, + cutlass_cppgen.DataType.f32: np.float32, + cutlass_cppgen.DataType.f64: np.float64, + cutlass_cppgen.DataType.s8: np.int8, + cutlass_cppgen.DataType.s32: np.int32, + } + except ImportError: + numpy_available = False + _library_to_numpy_dict = {} + return numpy_available + + +def is_numpy_tensor(inp) -> bool: + if is_numpy_available(): + import numpy as np + return isinstance(inp, np.ndarray) + return False + + +def numpy_library_type(inp) -> cutlass_cppgen.DataType: + if is_numpy_available(): + import numpy as np + if inp == np.float16: + return cutlass_cppgen.DataType.f16 + elif inp == np.float32: + return cutlass_cppgen.DataType.f32 + elif inp == np.float64: + return cutlass_cppgen.DataType.f64 + elif inp == np.int8: + return cutlass_cppgen.DataType.s8 + elif inp == np.int32: + return cutlass_cppgen.DataType.s32 + return None + + +def numpy_type(inp): + return _library_to_numpy_dict.get(inp, None) + + +def is_cupy_available(): + global cupy_available + if cupy_available is None: + try: + import cupy as cp + + cupy_available = True + _library_to_cupy_dict = { + cutlass_cppgen.DataType.f16: cp.float16, + cutlass_cppgen.DataType.f32: cp.float32, + cutlass_cppgen.DataType.f64: cp.float64, + cutlass_cppgen.DataType.s8: cp.int8, + cutlass_cppgen.DataType.s32: cp.int32, + } + except ImportError: + cupy_available = False + _library_to_cupy_dict = {} + return cupy_available + + +def is_cupy_tensor(inp) -> bool: + if is_cupy_available(): + import cupy as cp + return isinstance(inp, cp.ndarray) + return False + + +def cupy_library_type(inp) -> cutlass_cppgen.DataType: + if is_cupy_available(): + import cupy as cp + if inp == cp.float16: + return cutlass_cppgen.DataType.f16 + elif inp == cp.float32: + return cutlass_cppgen.DataType.f32 + elif inp == cp.float64: + return cutlass_cppgen.DataType.f64 + return None + + +def cupy_type(inp): + return _library_to_cupy_dict.get(inp, None) + + +def is_torch_available(): + global torch_available, _library_to_torch_dict, _torch_to_library_dict + if torch_available is None: + try: + import torch + + torch_available = True + _torch_to_library_dict = { + torch.half: cutlass_cppgen.DataType.f16, + torch.float16: cutlass_cppgen.DataType.f16, + torch.bfloat16: cutlass_cppgen.DataType.bf16, + torch.float: cutlass_cppgen.DataType.f32, + torch.float32: cutlass_cppgen.DataType.f32, + torch.double: cutlass_cppgen.DataType.f64, + torch.float64: cutlass_cppgen.DataType.f64, + torch.int8: cutlass_cppgen.DataType.s8, + torch.int32: cutlass_cppgen.DataType.s32, + torch.uint8: cutlass_cppgen.DataType.u8, + } + + _library_to_torch_dict = { + cutlass_cppgen.DataType.f16: torch.half, + cutlass_cppgen.DataType.f16: torch.float16, + cutlass_cppgen.DataType.bf16: torch.bfloat16, + cutlass_cppgen.DataType.f32: torch.float, + cutlass_cppgen.DataType.f32: torch.float32, + cutlass_cppgen.DataType.f64: torch.double, + cutlass_cppgen.DataType.f64: torch.float64, + cutlass_cppgen.DataType.s8: torch.int8, + cutlass_cppgen.DataType.s32: torch.int32, + cutlass_cppgen.DataType.u8: torch.uint8, + } + + def possibly_add_type(torch_type_name, cutlass_type): + # Only try adding the type if the version of torch being used supports it + if hasattr(torch, torch_type_name): + torch_type = getattr(torch, torch_type_name) + _torch_to_library_dict[torch_type] = cutlass_type + _library_to_torch_dict[cutlass_type] = torch_type + + possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3) + possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2) + + except ImportError: + torch_available = False + _torch_to_library_dict = {} + _library_to_torch_dict = {} + return torch_available + + +def is_torch_tensor(inp) -> bool: + if is_torch_available(): + import torch + return isinstance(inp, torch.Tensor) + return False + + +def torch_library_type(inp) -> cutlass_cppgen.DataType: + return _torch_to_library_dict.get(inp, None) + + +def torch_type(inp): + return _library_to_torch_dict.get(inp, None) + + +def is_bfloat16_available(): + global bfloat16_available + + if bfloat16_available is None: + try: + import bfloat16 + + bfloat16_available = True + except ImportError: + bfloat16_available = False + return bfloat16_available + + +def bfloat16_library_type(inp) -> cutlass_cppgen.DataType: + if is_bfloat16_available(): + import bfloat16 + if inp == bfloat16.bfloat16: + return cutlass_cppgen.DataType.bf16 + + +def bfloat16_type(inp): + if is_bfloat16_available(): + import bfloat16 + if inp == cutlass_cppgen.DataType.bf16: + return bfloat16.bfloat16 + + +def library_type(inp): + if inp in DataTypeSize: + return inp + + for cvt_fn in [ + bfloat16_library_type, + cupy_library_type, + numpy_library_type, + torch_library_type, + ]: + out = cvt_fn(inp) + if out is not None: + return out + + raise Exception(f"No available conversion from type {inp} to a library type.") + + +def _tensor_from_numpy(np_tensor): + dtype = library_type(np_tensor.dtype) + if np_tensor.flags.c_contiguous: + layout = cutlass_cppgen.LayoutType.RowMajor + elif np_tensor.flags.f_contiguous: + layout = cutlass_cppgen.LayoutType.ColumnMajor + return (dtype, layout) + + +def _tensor_from_torch(pt_tensor): + dtype = library_type(pt_tensor.dtype) + return (dtype, cutlass_cppgen.LayoutType.RowMajor) + + +def get_datatype_and_layout(tensor): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): + return _tensor_from_numpy(tensor) + elif is_torch_tensor(tensor): + return _tensor_from_torch(tensor) + elif isinstance(tensor, float) or isinstance(tensor, int): + return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +def get_tensor_shape(tensor, op="GEMM"): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): + return tensor.shape + elif is_torch_tensor(tensor): + size = tensor.size() + if op == "CONV": + # PyTorch Tensors have shape NCHW + return (size[0], size[2], size[3], size[1]) + else: + return tuple(tensor.size()) + elif isinstance(tensor, float) or isinstance(tensor, int): + return (1,) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + + +_math_operation_value_map = {x.value: x for x in MathOperation} + + +def backend_math_operation(math_op: MathOperation): + if math_op.value not in _math_operation_value_map.keys(): + raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.") + return _math_operation_value_map[math_op.value] + + +def construct_backend_td(td: cutlass_cppgen.TileDescription, + kernel_schedule: cutlass_cppgen.KernelScheduleType, + epilogue_schedule: cutlass_cppgen.EpilogueScheduleType, + tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription: + mi = td.math_instruction + backend_mi = MathInstruction( + mi.instruction_shape, + mi.element_a, + mi.element_b, + mi.element_accumulator, + mi.opcode_class, + backend_math_operation(mi.math_operation) + ) + cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1] + return TileDescription(td.threadblock_shape, td.stages, td.warp_count, + backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler) + + +def td_from_profiler_op(op) -> TileDescription: + """ + Converts the profiler's TileDescription in ``op`` into the backend TileDescription + + :param op: profiler Operation + + :returns: backend TileDescription + :rtype: cutlass_cppgen.backend.TileDescription + """ + kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None + eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None + tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None + return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule) + + +def td_from_profiler_td(td: TileDescription) -> TileDescription: + """ + Converts the profiler's TileDescription into the backend TileDescription + + :param td: profiler TileDescription + :type td: cutlass_cppgen.TileDescription + + :returns: backend TileDescription + :rtype: cutlass_cppgen.backend.TileDescription + """ + return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None) + + +def to_camel_case(snake_str): + return "".join(x.capitalize() for x in snake_str.lower().split("_")) + + +def getattr_enum(obj, attr_name): + # The attr_name is under the snake_case + camel_attr = to_camel_case(attr_name) + if hasattr(obj, camel_attr): + return getattr(obj, camel_attr) + else: + raise Exception(f"Invalid option: {attr_name}") diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py new file mode 100644 index 0000000000000000000000000000000000000000..16f6a185040f4c2f6167c6191c9bee766a92b1b9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py @@ -0,0 +1,41 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# +import importlib +from typing import Any + +def lazy_import(mod_name: str) -> Any: + class Lazy: + def __getattr__(self, name:str) -> Any: + module = importlib.import_module(mod_name) + return getattr(module, name) + + return Lazy() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..f53b1567978d17f2eaec0208d896aafb296f033f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py @@ -0,0 +1,196 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Profiler based on the cuda events +""" + +import re +import subprocess + +from cutlass_cppgen.utils.lazy_import import lazy_import +cuda = lazy_import("cuda.cuda") +cudart = lazy_import("cuda.cudart") +import numpy as np + +from cutlass_cppgen import CUTLASS_PATH +from cutlass_cppgen.backend.library import DataTypeSize +from cutlass_cppgen.op.op import OperationBase +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import is_numpy_tensor + + +class GpuTimer: + def __init__(self) -> None: + self.events = [ + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1], + ] + + def start(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + (err,) = cuda.cuEventRecord(self.events[0], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + def stop(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + (err,) = cuda.cuEventRecord(self.events[1], stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + pass + + def stop_and_wait(self, stream=None): + if not stream: + stream = cuda.CUstream(0) + + self.stop(stream) + if stream: + (err,) = cuda.cuStreamSynchronize(stream) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + else: + (err,) = cudart.cudaDeviceSynchronize() + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + + def duration(self, iterations=1): + err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1]) + if err != cuda.CUresult.CUDA_SUCCESS: + raise RuntimeError(f"CUDA Error {str(err)}") + return duration / float(iterations) + + +class CUDAEventProfiler: + def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None: + self.arguments = op.run(*args, **kwargs) + self.operation = op.operation + self.warmup_iterations = warmup_iterations + self.iterations = iterations + self.timer = GpuTimer() + + # + # Cutlass Python Interface Profiler + # + + def __call__(self): + for _ in range(self.warmup_iterations): + self.operation.run(self.arguments) + + self.timer.start() + for _ in range(self.iterations): + self.operation.run(self.arguments) + + self.timer.stop_and_wait() + runtime = self.timer.duration(self.iterations) + return runtime + + # + # CUTLASS Profiler + # + + def run_cutlass_profiler(self): + alpha = 1.0 + beta = 1.0 + + profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler" + kernel_name = self.operation.procedural_name() + verification_providers = "device" + provider = "cutlass" + problem_size = self.arguments.problem_size + + if "cutlass3x" in kernel_name: + # cutlass3x generator only have column-major output + layout_name = self.operation.layout_name_3x() + if layout_name[-1] == "t": + new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"]) + problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k) + kernel_name = kernel_name.replace(layout_name, new_layout_name) + + batch_count = self.arguments.batch_count + + cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \ + f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \ + f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\ + f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}" + + result = subprocess.getoutput(cmd) + + m = re.search(r"Runtime:\s+(?P\d+.\d+)", result) + runtime = float(m.group("runtime")) + + m = re.search(r"Bytes:\s+(?P\d+)", result) + bytes = int(m.group("bytes")) + + m = re.search(r"FLOPs:\s+(?P\d+)", result) + flops = int(m.group("flops")) + + # check if the problem size matches + assert bytes == self.bytes(problem_size, batch_count, beta) + assert flops == self.flops(problem_size, batch_count, beta) + + return runtime + + def bytes(self, problem_size, batch_count=1, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + bytes = ( + (DataTypeSize[self.operation.A.element] * m // 8) * k + + (DataTypeSize[self.operation.B.element] * n // 8) * k + + (DataTypeSize[self.operation.C.element] * m // 8) * n + ) + + if beta != 0: + bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n + + bytes *= batch_count + + return bytes + + def flops(self, problem_size, batch_count=1, beta=0.0): + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + + flops_ = (m * n * k) * 2 * batch_count + + if beta != 0: + flops_ += m * n * batch_count * 2 + + return flops_ + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..534eef47d810eb9f17a9ba6dbbe2e0dff935eb3f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py @@ -0,0 +1,63 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os +import sys + +from . import conv2d_operation +from . import conv3d_operation +from . import emit_kernel_listing +from . import gemm_operation + +if '-m' not in sys.argv: + # Do not import generator when running python -m cutlass_library.generator to + # avoid double-import warnings + from . import generator + +from . import library +from . import manifest +from . import rank_2k_operation +from . import rank_k_operation +from . import symm_operation +from . import trmm_operation +# Make enum types from library.py accessible via cutlass_library.* +from .library import * + +# Set up `source` to point to the path containing the CUTLASS source. +# Check first if the path contains a `source` subdirectory -- this will +# be the case when the package has been installed via pip. Otherwise, +# default to the root of CUTLASS. +install_source_path = os.path.join(__path__[0], 'source') +if os.path.isdir(install_source_path): + source_path = install_source_path +else: + source_path = os.path.join(__path__[0], '../..') diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..b674463a2c5795be8610883c4dc98a1e7123a01b --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py @@ -0,0 +1,621 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Conv2d kernels +""" + +import enum +import logging +import os.path +import shutil +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes +except ImportError: + from library import * + from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### + +# +class Conv2dOperation: + # + def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \ + group_mode = GroupMode.NoneGroup): + + self.operation_kind = OperationKind.Conv2d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + self.group_mode = group_mode + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.accumulator_type(): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = self.tile_description.procedural_name() + + # grouped conv + if self.group_mode != GroupMode.NoneGroup: + group_conv_name = f"{GroupModeNames[self.group_mode]}_" + else: + group_conv_name = "" + + if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'alignment': "%d" % self.A.alignment, + 'group_conv_name': group_conv_name + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.configuration_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv2dInstance: + def __init__(self): + # Emitter for CUTLASS 3 convolution operations + self.conv3x_emitter = EmitConv3xInstance() + self.template = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_group_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${group_mode}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + self.template_depthwise_direct_conv = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>, + cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ${threadblock_output_shape_n}, + ${threadblock_output_shape_p}, + ${threadblock_output_shape_q}>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + cutlass::MatrixShape<${stride_r}, ${stride_s}>, + cutlass::MatrixShape<${dilation_r}, ${dilation_s}> + >::Kernel; +""" + + def arch_number_to_type(self, arch: int): + return f"cutlass::arch::Sm{arch}" + + def emit(self, operation): + _LOGGER.debug("*** EmitConv2dInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + if hasattr(operation, 'is_3x') and operation.is_3x: + _LOGGER.debug("*** CUTLASS 3 operation") + return self.conv3x_emitter.emit(operation) + + _LOGGER.debug("*** CUTLASS 2 operation") + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support], + 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \ + MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + } + + if operation.group_mode == GroupMode.NoneGroup: + _LOGGER.debug("*** group_mode=NoneGroup") + return SubstituteTemplate(self.template, values) + + elif operation.group_mode == GroupMode.Depthwise: + _LOGGER.debug("*** group_mode=Depthwise") + values['group_mode'] = GroupModeTag[operation.group_mode] + # Setup other template params + values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0]) + values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1]) + values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2]) + + values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3]) + + values['filter_shape_r'] = str(operation.tile_description.filter_shape[0]) + values['filter_shape_s'] = str(operation.tile_description.filter_shape[1]) + + values['stride_r'] = str(operation.tile_description.stride[0]) + values['stride_s'] = str(operation.tile_description.stride[1]) + + values['dilation_r'] = str(operation.tile_description.dilation[0]) + values['dilation_s'] = str(operation.tile_description.dilation[1]) + + return SubstituteTemplate(self.template_depthwise_direct_conv, values) + + else: + _LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode]) + values['group_mode'] = GroupModeTag[operation.group_mode] + return SubstituteTemplate(self.template_group_conv, values) + +################################################################################################### +# +# Generator functions for all layouts +# +################################################################################################### + +# +def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128): + _LOGGER.debug("*** GenerateConv2dTensorOp") + + for tile in tile_descriptions: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + + if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): + + # + output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ + if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ + else [tile.math_instruction.element_accumulator,] + + for output_type in output_types: + A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) + B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) + C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type]))) + + manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) + +class EmitConv2dIncludes: + '''Emit includes that are specific to the operation.''' + + def __init__(self): + self.includes = ['conv2d_operation.h'] + self.emitter_3x = EmitConv3xIncludes() + + def operation_is_3x(self, operation) -> bool: + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def emit(self, operation) -> str: + if self.operation_is_3x(operation): + return self.emitter_3x.emit(operation) + + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitConv2dConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + + self.instance_emitter = EmitConv2dInstance() + self.includes_emitter = EmitConv2dIncludes() + + self.header_template = """ +/* + Generated by conv2d_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +""" + + self.instance_template = """ +${stub_begin} +${operation_instance} +// Derived class +struct ${operation_name} : + public ${operation_name}_base { }; +${stub_end} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.configuration_header = """ + +namespace cutlass { +namespace library { + +// Initialize all instances +void initialize_${configuration_name}(Manifest &manifest) { +""" + + self.configuration_instance = """${stub_begin} + using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< + ${operation_name}>; + + manifest.append(new cutlass::library::${operation_wrapper}< + Operation_${operation_name} + >( + "${operation_name}" + )); +${stub_end} +""" + + self.configuration_epilogue = "}\n" + + self.epilogue_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def operation_is_3x(self, operation): + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def __enter__(self): + """ + Open the configuration_file, and write the "header" C++ code to it. + + The "header" consists of a comment (that this is generated code, + so it should not be edited), and includes that are common + to all kinds of kernels. + """ + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + self.configuration_file = open(self.configuration_path, "w") + + self.configuration_file.write(SubstituteTemplate(self.header_template, { + 'configuration_name': self.configuration_name + })) + self.operations = [] + return self + + def emit(self, operation): + """ + Write three pieces of C++ code to the configuration_file + (that was opened by the __enter__ method above): + + 1. the header includes that are specific to the operation + (CUTLASS 2 vs. CUTLASS 3); + + 2. the "operation instance" (a "using" declaration ending in "_base"); and + + 3. the "operation name" (declaration and definition of a derived class + of the above operation instance). + + The "using" declaration turns a C++ class name, possibly namespace-qualified, + possibly also with angle brackets, into a C-style, easily demangled identifier. + """ + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit') + _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) + self.operations.append(operation) + + self.configuration_file.write(self.includes_emitter.emit(operation)) + + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = '#endif // 0' + + self.configuration_file.write(Template(self.instance_template).substitute({ + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'operation_instance': self.instance_emitter.emit(operation), + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + def __exit__(self, exception_type, exception_value, traceback): + """ + Write the rest of the C++ code to the configuration_file, and close the file. + + The "rest of the C++ code" has the following components. + + 1. Configuration header: Open the namespace(s), and open the definition + of the "initialize_${configuration_name}" registration function + that registers the operation with the Manifest. + ("Registration" helps turn C++ compile-time polymorphism + (via template parameters) into a run-time choice of parameters.) + + 2. Configuration instance: In the body of the registration function, + make a "using" declaration Operation_${operation_name} for the + operation type (which uses operation_name as its template argument). + Then, tell the manifest about the operation via a "manifest.append" call. + The argument of the call is a new instance of + "SomethingOperation" + (replace Something with a specific name). + + 3. Configuration epilogue: Close the definition of the registration function. + + 4. Epilogue template: Close the namespace(s). + """ + + _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + + self.configuration_file.write(SubstituteTemplate(self.configuration_header, { + 'configuration_name': self.configuration_name + })) + + for operation in self.operations: + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = "#endif // 0" + + if operation.group_mode == GroupMode.Depthwise: + kernel_name = 'DirectConvolution' + operation_wrapper = 'DirectConv2dOperation' + else: + kernel_name = 'ImplicitGemmConvolution' + operation_wrapper = 'Conv2dOperation' + if self.operation_is_3x(operation): + kernel_name = 'ConvUniversalAdapter' + operation_wrapper = 'ConvOperation3x' + + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'kernel_name': kernel_name, + 'operation_wrapper': operation_wrapper, + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +################################################################################################### +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..b96b6db74224e52bd90b6e184a62624475385352 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py @@ -0,0 +1,482 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Conv3d kernels +""" + +import enum +import logging +import os.path +import shutil +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes +except ImportError: + from library import * + from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### + +# +class Conv3dOperation: + # + def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \ + stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + self.operation_kind = OperationKind.Conv3d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + intermediate_type = '' + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = '' + + return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \ + inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride" + else: + configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}" + + return SubstituteTemplate( + configuration_name, + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + } + ) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.configuration_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv3dInstance: + def __init__(self): + # Emitter for CUTLASS 3 convolution operations + self.conv3x_emitter = EmitConv3xInstance() + self.template = """ + // Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name}_base = + typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}< + ${element_a}, + cutlass::layout::TensorNDHWC, + ${element_b}, + cutlass::layout::TensorNDHWC, + ${element_c}, + cutlass::layout::TensorNDHWC, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + cutlass::arch::OpMultiplyAdd, + ${iterator_algorithm}, + ${stride_support} + >::Kernel; +""" + + def emit(self, operation): + _LOGGER.debug("*** EmitConv3dInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + if hasattr(operation, 'is_3x') and operation.is_3x: + _LOGGER.debug("*** CUTLASS 3 operation") + return self.conv3x_emitter.emit(operation) + + _LOGGER.debug("*** CUTLASS 2 operation") + + warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm], + 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(), + 'stride_support': StrideSupportTag[operation.stride_support] + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### +# +# Generator functions for all layouts +# +################################################################################################### + +# +def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128): + + for tile in tile_descriptions: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + + if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]): + + # + output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \ + if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \ + else [tile.math_instruction.element_accumulator,] + + for output_type in output_types: + A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a])) + B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b])) + C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type]))) + + manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator)) + +class EmitConv3dIncludes: + '''Emit includes that are specific to the operation.''' + + def __init__(self): + self.includes = ['conv3d_operation.h'] + self.emitter_3x = EmitConv3xIncludes() + + def operation_is_3x(self, operation) -> bool: + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def emit(self, operation) -> str: + if self.operation_is_3x(operation): + return self.emitter_3x.emit(operation) + + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitConv3dConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name) + + self.instance_emitter = EmitConv3dInstance() + self.includes_emitter = EmitConv3dIncludes() + + self.header_template = """ +/* + Generated by conv3d_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +""" + + self.instance_template = """ +${stub_begin} +${operation_instance} +// Derived class +struct ${operation_name} : + public ${operation_name}_base { }; +${stub_end} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.configuration_header = """ + +namespace cutlass { +namespace library { + +// Initialize all instances +void initialize_${configuration_name}(Manifest &manifest) { +""" + + self.configuration_instance = """${stub_begin} + using Operation_${operation_name} = cutlass::conv::device::${kernel_name}< + ${operation_name}>; + + manifest.append(new cutlass::library::${operation_wrapper}< + Operation_${operation_name} + >( + "${operation_name}" + )); +${stub_end} +""" + + self.configuration_epilogue = "}\n" + + self.epilogue_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def operation_is_3x(self, operation): + """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)""" + return hasattr(operation, 'is_3x') and operation.is_3x + + def __enter__(self): + """ + Open the configuration_file, and write the "header" C++ code to it. + + The "header" consists of a comment (that this is generated code, + so it should not be edited), and includes that are common + to both the CUTLASS 2 and the CUTLASS 3 cases. + """ + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + self.configuration_file = open(self.configuration_path, "w") + + self.configuration_file.write(SubstituteTemplate(self.header_template, { + 'configuration_name': self.configuration_name + })) + self.operations = [] + return self + + def emit(self, operation): + """ + Write three pieces of C++ code to the configuration_file + (that was opened by the __enter__ method above): + + 1. the header includes that are specific to the operation + (CUTLASS 2 vs. CUTLASS 3); + + 2. the "operation instance" (a "using" declaration ending in "_base"); and + + 3. the "operation name" (declaration and definition of a derived class + of the above operation instance). + + The "using" declaration turns a C++ class name, possibly namespace-qualified, + possibly also with angle brackets, into a C-style, easily demangled identifier. + """ + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit') + _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name()) + self.operations.append(operation) + + self.configuration_file.write(self.includes_emitter.emit(operation)) + + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = '#endif // 0' + + self.configuration_file.write(Template(self.instance_template).substitute({ + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'operation_instance': self.instance_emitter.emit(operation), + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + def __exit__(self, exception_type, exception_value, traceback): + """ + Write the rest of the C++ code to the configuration_file, and close the file. + + The "rest of the C++ code" has the following components. + + 1. Configuration header: Open the namespace(s), and open the definition + of the "initialize_${configuration_name}" registration function + that registers the operation with the Manifest. + ("Registration" helps turn C++ compile-time polymorphism + (via template parameters) into a run-time choice of parameters.) + + 2. Configuration instance: In the body of the registration function, + make a "using" declaration Operation_${operation_name} for the + operation type (which uses operation_name as its template argument). + Then, tell the manifest about the operation via a "manifest.append" call. + The argument of the call is a new instance of + "SomethingOperation" + (replace Something with a specific name). + + 3. Configuration epilogue: Close the definition of the registration function. + + 4. Epilogue template: Close the namespace(s). + """ + + _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__') + _LOGGER.debug('*** configuration_path (file to write): ' + + str(self.configuration_path)) + _LOGGER.debug('*** configuration_name: ' + self.configuration_name) + + self.configuration_file.write(SubstituteTemplate(self.configuration_header, { + 'configuration_name': self.configuration_name + })) + + for operation in self.operations: + stub_begin = '' + stub_end = '' + # It can be useful to stub (comment) out instantiations for testing. + # In this case, one need only set is_stub to True. + is_stub = False + if is_stub: + stub_begin = "// STUB for now\n#if 0" + stub_end = "#endif // 0" + + kernel_name = 'ImplicitGemmConvolution' + operation_wrapper = 'Conv3dOperation' + if self.operation_is_3x(operation): + kernel_name = 'ConvUniversalAdapter' + operation_wrapper = 'ConvOperation3x' + + self.configuration_file.write(SubstituteTemplate(self.configuration_instance, { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'kernel_name': kernel_name, + 'operation_wrapper': operation_wrapper, + 'stub_begin': stub_begin, + 'stub_end': stub_end + })) + + self.configuration_file.write(self.configuration_epilogue) + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +################################################################################################### +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..33d6da1a4675c0bbd07315717a7f5ba0ba0dc10c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py @@ -0,0 +1,250 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting CUTLASS >= 3 convolution kernels +""" + +import enum +import os.path +import shutil +import logging +from string import Template + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +class EmitConv3xInstance: + def __init__(self): + _LOGGER.debug("*** EmitConv3xInstance::__init__") + + # Define epilogue type first, so that the mainloop type + # can use it with StageCountAutoCarveout. + self.template = """ + +// CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}" +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, + ${opcode_class_epi}, + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape + ${epi_tile_mn}, + ${element_accumulator}, + ${element_compute}, + ${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>, + ${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>, + ${epilogue_schedule} + // , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + +using ${operation_name}_mainloop = + typename cutlass::conv::collective::CollectiveBuilder< + ${arch}, + ${opcode_class_main}, + ${conv_kind}, // kFprop, kDgrad, or kWgrad + ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>, + ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>, + ${element_accumulator}, + ${mma_tile_shape}, // mma tile shape + ${cluster_shape}, // cluster shape + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + +using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>; + +// Unit tests call this "ConvKernel". +// Conv operator ${operation_name} +using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal< + ${operation_name}_problem_shape, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler} + >; +""" + + def arch_number_to_type(self, arch: int) -> str: + return f"cutlass::arch::Sm{arch}" + + def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str: + mma_m = cta_m + mma_n = cta_n + mma_k = cta_k + + if operation.arch >= 100: + # MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where + # mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version. + # If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated, + # otherwise 1sm kernel is allocated. + cta_m_per_mma_instruction = 1 + if "2sm" in operation.procedural_name() : + cta_m_per_mma_instruction = 2 + elif "1sm" in operation.procedural_name() : + cta_m_per_mma_instruction = 1 + elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 : + cta_m_per_mma_instruction = 2 + mma_m = cta_m * cta_m_per_mma_instruction + + # For all three kinds of convolutions, the tile shape's K mode + # differs from GEMM in that needs to be wrapped in a Shape. + # For Wgrad convolutions specifically, + # the N tile shape also needs to be wrapped in a Shape. + m_template = 'cute::_${mma_m}' + if operation.conv_kind == ConvKind.Wgrad: + n_template = 'cute::Shape' + else: + n_template = 'cute::_${mma_n}' + k_template = 'cute::Shape' + + mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'mma_m': mma_m, + 'mma_n': mma_n, + 'mma_k': mma_k + } + return Template(mma_tile_shape_template).substitute(values) + + def cluster_shape(self, operation) -> str: + m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)' + n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)' + k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)' + cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>' + values = { + 'cluster_shape_m': operation.tile_description.cluster_shape[0], + 'cluster_shape_n': operation.tile_description.cluster_shape[1], + 'cluster_shape_k': operation.tile_description.cluster_shape[2], + } + return Template(cluster_shape_template).substitute(values) + + def stage_count(self, operation) -> str: + # stages == 0 tells builder to pick the number of stages automatically + namespace_prefix = 'cutlass::conv::collective::' + if operation.tile_description.stages > 0: + return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>" + else: + return f"{namespace_prefix}StageCountAutoCarveout" + + def emit(self, operation) -> str: + _LOGGER.debug("*** EmitConv3xInstance::emit") + _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name()) + + # Identify the operation as CUTLASS 3 by its is_3x field + if (not hasattr(operation, 'is_3x')) or (not operation.is_3x): + raise RuntimeError("operation must be a CUTLASS 3 operation") + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + + cta_m, cta_n, cta_k = tile_shape + # account for static/dynamic cluster shapes + if operation.arch >= 100: + cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m + cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n + + warp_count = operation.tile_description.warp_count + epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule] + + # KernelScheduleTag and TileSchedulerTag both hard-code the + # namespace qualification of KernelScheduleAuto as + # "cutlass::gemm::collective::" (unless the tag is 'void'). + # + # For TileSchedulerTag, this namespace is fine, since CUTLASS 3 + # convolutions use the same tile schedulers (from the same + # cutlass::gemm::collective namespace) as GEMMs. + kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::') + tile_scheduler = TileSchedulerTag[operation.tile_scheduler] + opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class] + + values = { + 'operation_name': operation.procedural_name(), + 'conv_kind': ConvKindTag[operation.conv_kind], + 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'align_a': int(operation.A.alignment), + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'align_b': int(operation.B.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'align_c': int(operation.C.alignment), + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': LayoutTag[operation.D.layout], + 'align_d': int(operation.D.alignment), + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': opcode_class, + 'arch': self.arch_number_to_type(operation.arch), + 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k), + 'cluster_shape': self.cluster_shape(operation), + 'opcode_class_epi': opcode_class_epi, + 'opcode_class_main': opcode_class_main, + 'epi_tile_mn': epi_tile_mn, + 'stages': self.stage_count(operation), + 'kernel_schedule': kernel_schedule, + 'epilogue_schedule': epilogue_schedule, + 'tile_scheduler': tile_scheduler, + 'element_compute': DataTypeTag[operation.element_compute] + } + return Template(self.template).substitute(values) + +class EmitConv3xIncludes: + def __init__(self): + _LOGGER.debug("*** EmitConv3xIncludes::__init__") + self.includes = ['conv_operation_3x.hpp', + 'cutlass/conv/device/conv_universal_adapter.hpp', + 'cutlass/conv/kernel/conv_universal.hpp', + 'cutlass/conv/collective/collective_builder.hpp', + 'cutlass/epilogue/collective/collective_builder.hpp'] + + def emit(self, operation) -> str: + _LOGGER.debug("*** EmitConv3xIncludes::emit") + return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \ + "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////" diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py new file mode 100644 index 0000000000000000000000000000000000000000..fbe52eb587ab1b5e4595739be5790151b00e0a70 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py @@ -0,0 +1,868 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# +# +# \brief Generates the CUTLASS kernel listing with kernel filtering +# + +# + +############################################################################### +# Example usage: +# generator.py --operations all --generator-target kernel_listing \ +# --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports +############################################################################### + +import collections +import csv +import json +import math +import os + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +audit_csv_fields = [ + "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD", + "Layout_A", "Layout_B", "Layout_C", "Layout_D", + "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D", + "1SM/2SM", + "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types", + "Test Counts" +] + +audit_csv_runtime_fields = [ + "KerneIndex", "KernelName", + "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K", + "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K", + "M", "N", "K", "L", "Alpha_val", "Beta_val", + "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled" +] + +def hash_cutlass_string(input_string): + mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + + # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1') + output = re.sub(mma_cluster_shape_pattern, "", input_string) + + return output + +def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b): + # Define a dictionary mapping the detected types to runtime values + datatype_map = { + 'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b, + 'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + 'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b, + } + + # Regular expression to detect all the keys in datatype_map + pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')') + + # Replace detected patterns using the dictionary + updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name) + + return updated_kernel_name + +# This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k. +def get_kernel_features(operation, kernel_name, + dynamic_datatype, runtime_input_datatype): + numcta_inst = "2sm" if "2sm" in kernel_name else "1sm" + math_inst = operation.tile_description.math_instruction + + if dynamic_datatype: + dtype_name_A = runtime_input_datatype[0] + dtype_name_B = runtime_input_datatype[1] + else: + dtype_name_A = DataTypeNames[operation.A.element] + dtype_name_B = DataTypeNames[operation.B.element] + + layout_name_A = ShortLayoutTypeNames[operation.A.layout] + layout_name_B = ShortLayoutTypeNames[operation.B.layout] + layout_name_C = ShortLayoutTypeNames[operation.C.layout] + layout_name_D = ShortLayoutTypeNames[operation.D.layout] + + scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void + scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void) + audit_vals = [ + "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM", + kernel_name, + dtype_name_A, + dtype_name_B, + DataTypeNames[operation.C.element], + DataTypeNames[operation.tile_description.math_instruction.element_accumulator], + DataTypeNames[operation.element_epilogue], + DataTypeNames[operation.D.element], + DataTypeNames[scale_factor_D_type], + DataTypeNames[scale_factor_A_type], + layout_name_A, + layout_name_B, + layout_name_C, + layout_name_D, + str(operation.A.alignment), + str(operation.B.alignment), + str(operation.C.alignment), + str(operation.D.alignment), + numcta_inst, + "Y" if 'stream_k' in kernel_name else "N", + ] + return audit_vals + +# This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta. +def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster): + math_inst = operation.tile_description.math_instruction + audit_vals = [ + str(math_inst.instruction_shape[0]), + str(math_inst.instruction_shape[1]), + str(math_inst.instruction_shape[2]), + str(operation.tile_description.threadblock_shape[0]), + str(operation.tile_description.threadblock_shape[1]), + str(operation.tile_description.threadblock_shape[2]), + str(operation.tile_description.cluster_shape[0]), + str(operation.tile_description.cluster_shape[1]), + str(operation.tile_description.cluster_shape[2]), + str(cluster_shape[0]), + str(cluster_shape[1]), + str(cluster_shape[2]), + str(fallback_cluster_shape[0]), + str(fallback_cluster_shape[1]), + str(fallback_cluster_shape[2]), + str(problem_shape[0]), + str(problem_shape[1]), + str(problem_shape[2]), + str(problem_shape[3]), + str(alpha), + str(beta), + "Y" if dynamic_datatype else "N", + "Y" if dynamic_cluster else "N", + ] + return audit_vals + + +def _getSubOperationType(kernel): + + if kernel.operation_kind == OperationKind.Gemm: + return GemmKindNames[kernel.gemm_kind] + elif kernel.operation_kind == OperationKind.Conv2d: + return "conv_" + ConvKindNames[kernel.conv_kind] + elif kernel.operation_kind == OperationKind.Syrk: + return "syrk_" + SyrkKindNames[kernel.syrk_kind] + elif kernel.operation_kind == OperationKind.Trmm: + return "trmm_" + TrmmKindNames[kernel.trmm_kind] + elif kernel.operation_kind == OperationKind.Symm: + return "symm_" + SymmKindNames[kernel.symm_kind] + else: + raise Exception("Unsupported kernel type") + +def _get_inst_shape(math_instruction): + return "".join(str(x) for x in math_instruction.instruction_shape) + +def _is_simt_inst(math_instruction): + return _get_inst_shape(math_instruction) in ["111","114"] + +def _getInstType(input_precision, accumulate_precision, math_instruction): + + # inst_shape + inst_shape = _get_inst_shape(math_instruction) + + # input precision + if input_precision == "fp32" and inst_shape != "111": + inp = "tf32" + else: + inp = input_precision + + # Handle SIMT op types first + if _is_simt_inst(math_instruction): + + simt_input_precision_to_inst = { + "fp32": "FFMA", + "fp64": "DFMA", + "fp16": "HFMA", + "int8": "IDP4A", + } + inst = simt_input_precision_to_inst[input_precision] + + else: # Tensor op instructions + + if accumulate_precision == "cf64": + fp64_acc_map = { + MathOperation.multiply_add_complex_gaussian : "gz", + MathOperation.multiply_add_complex : "z", + } + acc = fp64_acc_map[math_instruction.math_operation] + else: + tensor_op_acc_map = { + "fp32" : "s", + "cf32" : "s", + "fp16" : "h", + "int32": "i", + "fp64" : "d", + } + acc = tensor_op_acc_map[accumulate_precision] + + inst = "{}{}{}".format(acc, inst_shape, inp) + + return inst +# TODO: Computes FLOps/Bytes for GEMM - revisit for conv +def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1): + assert not (batch_count > 1 and num_groups > 1) + + # TODO: adjust for sparsity + gmem_bytes = ( + (DataTypeSize[operation.A.element] * m // 8) * k + + (DataTypeSize[operation.B.element] * n // 8) * k + + (DataTypeSize[operation.C.element] * m // 8) * n + ) + + # TODO: complex-valued support + flops = 2 * (m * n * k) + + if bool(beta): + gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n + flops += 2 * m * n + + multiplier = max(batch_count, num_groups) + gmem_bytes *= multiplier + flops *= multiplier + + return flops / gmem_bytes + +def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode + ): + # For functional testing, we prefer to run reference computing on device if any + reference_device_archs = ["100a", "103a"] + run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False + profiler_flags_for_verification = "device" if run_reference_on_device else "host" + + # beta values for L0 and L1 + # TODO: randomize beta values for wider coverage + beta_values = [0.5] + + is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"]) + + is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch + + if (mode == "functional_L0") and is_supported_arch: + problem_waves = [0.5, 1.25, 2.5] + + # + # Dense Gemm + # + + sm100_mma_data_type_general = [ + 'gemm_f16_f16_f16_f16_f16', + 'gemm_f16_f16_f16_void_f16', + #'gemm_f16_f16_f32_f16_f16', + 'tf32gemm_f32_f32_f32_f32_f32', + 'bf16gemm_f32_f32_f32_f32_f32', + ] + + exclude_archs = arch not in ("103a") + if exclude_archs: + sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8') + + sm100_mma_data_type_runtime_dtype = [ + 'gemm.*f4_f4_f32_f32_f32', + 'gemm.*f6_f6_f32_f32_f32', + 'gemm.*f8_f8_f32_f32_f32', + ] + + sm100_mma_cluster_size = [ + '8x1x1', + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + # Restrict to two layouts to reduce L0 build and test time. + sm100_mma_layouts = [ + 'tnt', + 'ntn' + ] + + # regex list must be in kernel procedural name order + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + + # + # Block Scale Gemm + # + + block_scaled_data_type = [ + # runtime datatypes + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2', + #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2', + ] + + block_scaled_tile_k = ['x128_', 'x256_'] + + sm103_block_scaled_data_type = [ + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2', + 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1', + ] + + sm103_block_scaled_tile_k = ['x768_'] + + block_scaled_cluster_size = [ + '4x4x1', '2x1x1', + '0x0x1' # dynamic cluster + ] + + block_scaled_layouts = ['tnt'] + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + sm103_block_scaled_prefetch_policy = ['tmapf'] + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*" + + if arch in ["100a", "100f"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + elif arch in ["101a", "101f", "110a", "110f"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" + elif arch in ["103a"]: + kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({sm100_mma_filter_regex_1sm_runtime})|" \ + f"({sm100_mma_filter_regex_2sm_runtime})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})|" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + elif arch in ["120a", "120f", "121a", "121f"]: + + # blockscaled sm120_mma kernels + blockscaled_sm120_mma_kernel_cta_tiles = [ + [ '128x128' ] + ] + + # Restrict to two layouts to reduce L0 build and test time. + blockscaled_sm120_mma_layouts = [ 'tn' ] + filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*" + + problem_waves = [0.5, 1.25, 2.5] + + kernel_filter = f"({filter_regex_blockscaled_sm120_mma})" + else: + error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f" + raise Exception(error_message) + + elif mode == "functional_L1": + sm100_mma_cluster_size = [ + '0x0x1' # dynamic cluster + ] + # Restrict to two layouts to reduce L1 build and test time. + sm100_mma_layouts = ['tnt', 'ntn'] + sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*" + sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*" + block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2', + 'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2', + ] + + sm103_block_scaled_data_type = [ + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2', + 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1', + ] + + block_scaled_cluster_size = ['0x0x1'] + block_scaled_layouts = ['tnt'] + + # regex list must be in kernel procedural name order + block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*" + sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*" + + filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \ + f"({sm100_mma_filter_regex_2sm})|" \ + f"({block_scaled_filter_regex_1sm})|" \ + f"({block_scaled_filter_regex_2sm})" \ + f"({sm103_block_scaled_filter_regex_1sm})|" \ + f"({sm103_block_scaled_filter_regex_2sm})" + # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times + sm120_mma_kernel_cta_tiles = [ + # h1688, s1688, i16832, i8816 + [ '256x128' ], + # d884, c1688, + [ '128x128' ], + # c1688, z884 + [ '128x64' ], + # gz884 + [ '64x64' ] + ] + + # sm120 MMA instruction shapes, planar complex type excluded as they are not required + sm120_mma_instruction_shapes = [ + [ 'h1688gemm_(?!planar_complex)', + 's1688gemm_f16', + 's1688gemm_bf16', + 's1688gemm_tf32', + 'i16832gemm', + 'i8816gemm' ], + [ 'd884gemm', 'c1688tf32gemm' ] , + [ 'c1688gemm', + 'z884gemm' ], + [ 'gz884gemm'] + ] + + # It's not pretty, but not sure why different instructions support different tile sizes. + filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*" + filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*" + filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*" + filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*" + + filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})" + + problem_waves = [0.5, 1.25, 2.5] + + if arch in ["120a", "120f", "121a", "121f"]: + kernel_filter = f"({filter_regex_sm120_mma})" + else: + kernel_filter = f"({filter_regex_sm100_mma})" + else: + raise ValueError() + + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + + audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv") + + audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv") + + kernel_filter_re = re.compile(kernel_filter) + testcase_counter = 0 + kernels_emitted = 0 + kernels_total = 0 + + perf_json_list = [] + kernel_name_set = set() + + testlist_csv_fields = ["testcase", "metadata"] + testlist_csv_rows = [] + auditlist_csv_map = {} + auditlist_csv_params_map = {} + + kernel_features = {} + + for cc in manifest.operations[OperationKind.Gemm].keys(): + for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items(): + assert(len(operation_l) == 1) + kernels_total += 1 + if len(kernel_filter_re.findall(kernel_name)) == 0: + continue + # Only test f16 I/O void C kernels in void C kernel set + # Exception: Use void C kernels for more accurate perf testing + if '_void_' in kernel_name and 'perf_' not in mode: + if 'f16_f16_f16_void_f16' not in kernel_name : + continue + + kernels_emitted += 1 + kernel_name_set.add(kernel_name) + hashed_kernel_name = hash_cutlass_string(kernel_name) + operation = operation_l[0] + + dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0 + or operation.tile_description.cluster_shape[1] == 0) + + dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name + + runtime_input_datatypes = [None] + + if dynamic_datatype: + if "f4_f4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "f4_f6" in kernel_name: + runtime_input_datatypes = [['e2m1','e3m2']] + elif "f4_f8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "f6_f4" in kernel_name: + runtime_input_datatypes = [['e3m2','e2m1']] + elif "f6_f6" in kernel_name: + runtime_input_datatypes = [['e3m2','e3m2']] + elif "f6_f8" in kernel_name: + runtime_input_datatypes = [['e3m2','e4m3']] + + elif "f8_f4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "f8_f6" in kernel_name: + runtime_input_datatypes = [['e4m3','e3m2']] + elif "f8_f8" in kernel_name: + runtime_input_datatypes = [ + # mask out those not covered in statically encoded test cases + # ['e5m2','e4m3'], + # ['e4m3','e5m2'], + ['e4m3','e4m3'] + ] + + # block scaled kernels + elif "ue8m0xf4_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue4m3xf4_ue4m3xf4" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m1']] + elif "ue8m0xf4_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m1','e2m3']] + elif "ue8m0xf4_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e2m1','e4m3']] + + elif "ue8m0xf6_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m1']] + elif "ue8m0xf6_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e2m3','e2m3']] + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + + elif "ue8m0xf8_ue8m0xf4" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m1']] + elif "ue8m0xf8_ue8m0xf6" in kernel_name: + runtime_input_datatypes = [['e4m3','e2m3']] + elif "ue8m0xf8_ue8m0xf8" in kernel_name: + runtime_input_datatypes = [['e4m3','e4m3']] + + if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + profiler_flags_for_verification = "host" + + # reduce L1 test runtime if reference kernel is not running on device. + if mode == "functional_L1" and profiler_flags_for_verification == "host" : + problem_waves = [0.5, 2.5] + + + if dynamic_cluster: + if mode == "functional_L0": + runtime_cluster_shapes = [[1,1,1], [2,2,1]] + else: + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]] + # reduce L1 test runtime if reference kernel is not running on device. + if profiler_flags_for_verification == "host": + runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]] + cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape + else: + runtime_cluster_shapes = [operation.tile_description.cluster_shape] + cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0]) + cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1]) + cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2]) + + alignment_a = operation.A.alignment + alignment_b = operation.B.alignment + alignment_c = operation.C.alignment + alignment_ab_max = max(alignment_a, alignment_b) + + layout3x = operation.layout_name_3x() + data_types = operation.datatype_name_3x() + + ctas_per_mma_instruction = 1 + if '_2sm' in kernel_name: + ctas_per_mma_instruction = 2 + valid_cluster_shapes = [] + + # Remove any cluster shapes that have cluster_m that is not divisible by 2 + for cs in runtime_cluster_shapes: + if cs[0] % 2 == 0: + valid_cluster_shapes.append(cs) + runtime_cluster_shapes = valid_cluster_shapes + + kernel_problem_waves = problem_waves + if mode == "functional_L0" or mode == "functional_L1": + # for functional testing, we want to perturb just a little from even shapes + # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not + # -16 ensures that we are TMA aligned even for FP8/Int8 + min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max + max_k = (cta_tile_shape_k*8) - alignment_ab_max + problem_shapes_k = [min_k, max_k] + sm_count = 16 + swizzle_sizes = [0] + # Larger k and less than half wave trigger streamk +separate reduction case to be generated + if 'stream_k' in kernel_name: + problem_shapes_k = [max_k, cta_tile_shape_k*32] + kernel_problem_waves = [0.125, 1.25, 2.5] + else: + raise ValueError + + if "void" in kernel_name: + beta_values = [0] + + alignment_shift_m = max(alignment_c, alignment_a) + alignment_shift_n = max(alignment_c, alignment_b) + + is_first_line = True + for index_waves, waves in enumerate(kernel_problem_waves): + for index_k, k in enumerate(problem_shapes_k): + for beta in beta_values: + for cluster_shape in runtime_cluster_shapes: + for runtime_input_datatype in runtime_input_datatypes: + for swizzle_size in swizzle_sizes: + grid_size = waves * sm_count + cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape) + if cluster_shape_m >= cluster_shape_n: + grid_m = cluster_shape_m + grid_n = grid_size / grid_m + grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1) + else: + grid_n = cluster_shape_n + grid_m = grid_size / grid_n + grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1) + + verification_required = False + if mode == "functional_L0" or mode == "functional_L1": + if '_void_' not in kernel_name: + verification_required = True + + m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max) + n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max) + k = int(k) + + # For functional testing, we want to perturb just a little from even shapes. + # Only do this if the perturbation does not cause one of the dimensions of the + # problem size to go to zero. This can occur for blockscaling kernels for which + # the alignment requirements for A and B can be quite large (e.g., 256). + if m > alignment_shift_m: + m -= alignment_shift_m + if n > alignment_shift_n: + n -= alignment_shift_n + + if '_n32t32_' in kernel_name: + continue + batch_count = 1 + if mode == "functional_L0" or mode == "functional_L1" : + if index_waves == 0 and index_k == 0 : + batch_count = 3 if mode == "functional_L0" else 5 + gemm_op = "gemm" + + grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind) + num_groups = 1 + if grouped: + gemm_op = "grouped_gemm" + num_groups = 3 # small to limit test time in host block-scaled reference kernels + batch_count = 1 + elif "bstensorop" in kernel_name: + gemm_op = "block_scaled_gemm" + elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind): + gemm_op = "blockwise_gemm" + + problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)] + + assert m > 0 and n > 0 and k > 0 + + # Emit per-testcase metadata for perf testing usage, eventually in perf database + metadata_dict = { + "input_params": { + 'problem_size_category' : problem_size_category, + 'operation' : _getSubOperationType(operation), + 'datatype' : data_types, + 'layout' : layout3x, + 'm' : m, + 'n' : n, + 'k' : k, + 'beta' : beta, + 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups) + }, + "runtime_params": { + 'ctas_per_mma_instruction' : ctas_per_mma_instruction, + 'tilesize_m' : cta_tile_shape_m, + 'tilesize_n' : cta_tile_shape_n, + 'tilesize_k' : cta_tile_shape_k, + 'cluster_shape_m' : cluster_shape_m, + 'cluster_shape_n' : cluster_shape_n, + } + } + + cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m + cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n + cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k + + + if dynamic_datatype: + runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype) + metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a + metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b + + testcase_metadata = [ + f"cutlass_profiler --operation={gemm_op}" + + (f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") + + f" --error-on-no-match --error-if-nothing-is-profiled" + + f" --kernels={kernel_name}" + + f" --m={str(m)}" + + f" --n={str(n)}" + + f" --k={str(k)}" + + (f" --num_groups={str(num_groups)}" if grouped else "") + + f" --cluster_m={str(cluster_shape_m)}" + + f" --cluster_n={str(cluster_shape_n)}" + + f" --cluster_k={str(cluster_shape_k)}" + + f" --cluster_m_fallback={str(cluster_m_fallback)}" + + f" --cluster_n_fallback={str(cluster_n_fallback)}" + + f" --cluster_k_fallback={str(cluster_k_fallback)}" + + f" --beta={str(beta)}" + + ("" if grouped else f" --batch_count={str(batch_count)}") + + f" --swizzle_size={str(swizzle_size)}" + + f" --verification-required={str(verification_required).lower()}" + ] \ + + output_dynamic_datatype = dynamic_datatype + if output_dynamic_datatype: + testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" + + f" --runtime_input_datatype_b={runtime_datatype_b}") + + testcase_metadata.append(json.dumps(metadata_dict)) + testlist_csv_rows.append(testcase_metadata) + testcase_counter += 1 + + alpha = 1.0 + + if dynamic_datatype: + hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b) + + # If kernel_name is new, initialize its feature set with defaults + if hashed_kernel_name not in kernel_features: + kernel_features[hashed_kernel_name] = { + "is_support_dynamic_cluster": False, + "is_support_dynamic_datatype": False, + } + + # Update features for the hashed kernel name + kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster + kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype + + if hashed_kernel_name not in auditlist_csv_params_map: + auditlist_csv_params_map[hashed_kernel_name] = [] + + audit_row_params = get_kernel_params( + operation, + hashed_kernel_name, + (cluster_shape_m, cluster_shape_n, cluster_shape_k), + (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback), + (m, n, k, batch_count), + alpha, beta, + dynamic_datatype, dynamic_cluster + ) + + auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params) + + if hashed_kernel_name not in auditlist_csv_map: + audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype) + auditlist_csv_map[hashed_kernel_name] = audit_row + + with open(outfile_name, 'w') as testlist_csv: + csv_writer = csv.writer(testlist_csv, delimiter=',') + csv_writer.writerow(testlist_csv_fields) + csv_writer.writerows(testlist_csv_rows) + + with open(audit_file_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_fields) + for hashed_kernel_name, row in auditlist_csv_map.items(): + # Append the dynamic features as "Y" or "N" + dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N" + dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N" + test_count = len(auditlist_csv_params_map[hashed_kernel_name]) + csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count]) + + with open(audit_file_params_name, 'w') as auditlist_csv: + csv_writer = csv.writer(auditlist_csv, delimiter=',') + csv_writer.writerow(audit_csv_runtime_fields) + for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1): + for i, row in enumerate(rows): + if i == 0: + csv_writer.writerow([kernel_index, hashed_kernel_name] + row) + else: + csv_writer.writerow(["", ""] + row) + + print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.") + + # Generate a newline separated list of kernel filters + assert(len(kernel_name_set) == kernels_emitted) + output_filter_enabled = True + if output_filter_enabled: + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name, "w") as file: + kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set)) + for kernel_name in kernel_name_set: + file.write(kernel_name + "\n") + + # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together. + if mode == "functional_L0" or mode == "functional_L1": + # Sort the .csv file + outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv") + with open(outfile_name) as file: + data = file.readlines() + data.sort() + with open(outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + # Sort the kernel list + kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list") + with open(kernel_filter_outfile_name) as file: + data = file.readlines() + data.sort() + with open(kernel_filter_outfile_name, 'w') as file: + for i in range(len(data)): + file.write(data[i]) + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..0d2449e769303b738212cdcd896c9f2793ca2632 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py @@ -0,0 +1,1613 @@ + +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting GEMM kernels +""" + +import collections +import enum +import functools +import logging +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +_LOGGER = logging.getLogger(__name__) + +################################################################################################### +# +# Data structure modeling a GEMM operation +# +################################################################################################### + +# +class GemmOperation: + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, + kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, + tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False, + ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None, + ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None): + + kinds_3x = { + GemmKind.Universal3x, + GemmKind.SparseUniversal3x, + GemmKind.BlockScaledUniversal3x, + GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, + GemmKind.BlockwiseUniversal3x, + GemmKind.GroupedBlockwiseUniversal3x, + } + self.is_3x = gemm_kind in kinds_3x + self.prefix = "3x" if self.is_3x else "" + self.operation_kind = OperationKind.Gemm + self.arch = arch + self.tile_description = tile_description + self.gemm_kind = gemm_kind + self.A = A + self.B = B + self.C = C + self.D = D + + if is_block_scaled(gemm_kind): + self.ScaleFactorA = ScaleFactorA + self.ScaleFactorB = ScaleFactorB + self.ScaleFactorD = ScaleFactorD["tensor"] + self.ScaleFactorVectorSize = ScaleFactorD["vector_size"] + + if is_blockwise(gemm_kind): + self.ScaleFactorMVecSize = ScaleFactorMVecSize + self.ScaleFactorNVecSize = ScaleFactorNVecSize + self.ScaleFactorKVecSize = ScaleFactorKVecSize + + if self.D == None: + self.D = self.C + + if not self.is_3x: + assert(kernel_schedule == KernelScheduleType.ScheduleAuto) + assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto) + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + + if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination: + self.epilogue_functor = EpilogueFunctor3x.LinearCombination + + self.swizzling_functor = swizzling_functor + self.tile_scheduler = tile_scheduler + + # Only enable mixed input mode and mixed input shuffle for Hopper + self.mixed_input_mode = None + if self.is_mixed_input() and self.arch >= 90 and self.arch < 100: + self.mixed_input_mode = mixed_input_mode + self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and', + MathOperation.multiply_add_fast_accum: 'fastaccum', + } + + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else "" + + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + short_math_name = self.short_math_name() if not self.is_3x else "" + + return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind]) + + # Generates a string representing the MMA instruction. + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + element_sfa = "" + element_sfb = "" + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.is_mixed_input(): + extended_name = "${core_name}_${element_a}_${element_b}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + elif is_blockwise(self.gemm_kind): + extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}" + element_sfa = DataTypeNames[self.accumulator_type()] + element_sfb = DataTypeNames[self.accumulator_type()] + else: + extended_name = "${core_name}" + if self.C.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_" + extended_name + if self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name += "_${element_a}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_sfa' : element_sfa, + 'element_b': DataTypeNames[self.B.element], + 'element_sfb' : element_sfb, + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def mixed_input_mode_name(self): + mode_name_mapping = { + MixedInputMode.ConvertOnly: "_cvt", + MixedInputMode.ScaleOnly: "_scl", + MixedInputMode.ScaleWithZeroPoint: "_sclzr" + } + mode_name = mode_name_mapping.get(self.mixed_input_mode, "") + if self.mixed_input_shuffle: + mode_name = mode_name + "_shfl" + return mode_name + + def extended_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element], + core_name = self.core_name()) + + if is_block_scaled(self.gemm_kind): + d_type_names = DataTypeNames[self.D.element] + + if self.ScaleFactorD.element != DataType.void: + d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names + + extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.ScaleFactorA], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.ScaleFactorB], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + core_name = self.core_name()) + + if is_blockwise(self.gemm_kind): + d_type_names = DataTypeNames[self.D.element] + + extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_sfa = DataTypeNames[self.accumulator_type()], + element_a = DataTypeNames[self.A.element], + element_sfb = DataTypeNames[self.accumulator_type()], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = d_type_names, + sfvec_m_size = self.ScaleFactorMVecSize, + sfvec_n_size = self.ScaleFactorNVecSize, + sfvec_k_size = self.ScaleFactorKVecSize, + core_name = self.core_name()) + + if self.mixed_input_mode != None: + extended_name = extended_name + self.mixed_input_mode_name() + return extended_name + + def datatype_name_3x(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element]) + return datatype_name + + # Generates a short string representing the AB layout tags (e.g. nt or tn) + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # Generates a short string representing the ABC layout tags (e.g. ntn or tnn) + def layout_name_3x(self): + if self.is_complex() or self.is_planar_complex(): + return "{}{}{}".format( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)], + ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)]) + else: + return "{}{}{}".format( + ShortLayoutTypeNames[self.A.layout], + ShortLayoutTypeNames[self.B.layout], + ShortLayoutTypeNames[self.C.layout]) + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name_3x(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + + if is_block_scaled(self.gemm_kind): + if self.ScaleFactorD.element != DataType.void: + return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout] + + return EpilogueScheduleSuffixes[self.epilogue_schedule] + + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + def get_collective_tile_shape(self): + """ + Get the tile shape passed to the collective builder. + On Blackwell, this is different than the operation.tile_description.tile_shape. + """ + is_sm100_kernel = (self.arch == 100 or self.arch == 103) + if not is_sm100_kernel: + return self.tile_description.tile_shape + + opcode_class_main = self.tile_description.math_instruction.opcode_class + instruction_shape = self.tile_description.math_instruction.instruction_shape + tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape + if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]: + tile_shape_m = instruction_shape[0] + tile_shape_n = instruction_shape[1] + return (tile_shape_m, tile_shape_n, tile_shape_k) + + # Generates the full kernel function name + def procedural_name(self): + return self._procedural_name + + @functools.cached_property + def _procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + if self.arch >= 90: + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}" + tile_shape = self.get_collective_tile_shape() + return kernel_name_template.format( + p = self.prefix, + ar = self.arch, + op = opcode_class_name, + ex = self.extended_name_3x(), + ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + s = self.layout_name_3x(), + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name_3x(), + e = self.epilogue_schedule_name_3x()) + else: + threadblock = self.tile_description.procedural_name() + return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( + p = self.prefix, + op = opcode_class_name, + ex = self.extended_name(), + tb = threadblock, + l = self.layout_name(), + a = str(max(self.A.alignment, self.B.alignment))) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + + def __hash__(self): + return hash(self.configuration_name()) + + def __eq__(self, other): + return self.configuration_name() == other.configuration_name() + +################################################################################################### +# +# Data structure modeling a grouped GEMM operation +# +################################################################################################### + +# +class GroupedGemmOperation(GemmOperation): + # + def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + scheduler_mode = GroupScheduleMode.Device): + super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor, swizzling_functor) + + self.scheduler_mode = scheduler_mode + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + base = super().procedural_name() + return SubstituteTemplate( + base + "_schedule${schedule}", + { + 'schedule': ShortGroupScheduleModeNames[self.scheduler_mode] + }) + + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::Gemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + self.gemm_complex_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::GemmComplex< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${transform_a}, + ${transform_b}, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_complex_template if operation.is_complex() else self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + +class EmitSparseGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + + +# +class EmitGemmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + self.gemm_template_interleaved = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor + } + + if operation.A.layout in transpose_layouts.keys() and \ + operation.B.layout in transpose_layouts.keys() and \ + operation.C.layout in transpose_layouts.keys(): + + instance_layout_A = transpose_layouts[operation.A.layout] + instance_layout_B = transpose_layouts[operation.B.layout] + instance_layout_C = transpose_layouts[operation.C.layout] + + gemm_template = self.gemm_template + else: + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + + gemm_template = self.gemm_template_interleaved + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(gemm_template, values) + + +################################################################################################### + +class EmitGemmUniversal3xInstance: + ''' Responsible for emitting a CUTLASS 3.x template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/gemm/gemm.h", + "cutlass/numeric_types.h", + "cutlass/gemm/kernel/gemm_universal.hpp", + "cutlass/gemm/collective/collective_builder.hpp", + "cutlass/epilogue/collective/collective_builder.hpp", + "cutlass/detail/blockwise_scale_layout.hpp", + ] + self.builtin_epilogue_functor_template = \ +"""${epilogue_functor}< + ${element_d}, + ${element_epilogue}, + ${element_c}, + ${element_epilogue} + >""" + + self.gemm_template = """ + +using ${operation_name}_epilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ${arch}, ${opcode_class_epi}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${epi_tile_mn}, + ${element_accumulator}, ${element_epilogue}, + ${element_c}, ${layout_c}, ${align_c}, + ${element_d}, ${layout_d}, ${align_d}, + ${epilogue_schedule}, + ${epilogue_functor} + >::CollectiveOp; + +${mixed_dtype_prepare_code} +${blockwise_prepare_code} + +using ${operation_name}_mainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ${arch}, ${opcode_class_main}, + ${element_a}, ${layout_a}, ${align_a}, + ${element_b}, ${layout_b}, ${align_b}, + ${element_accumulator}, + cute::Shape, + cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>, + ${stages}, + ${kernel_schedule} + >::CollectiveOp; + +// Gemm operator ${operation_name} +using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< + ${problem_shape}, + ${operation_name}_mainloop, + ${operation_name}_epilogue, + ${tile_scheduler}>; + +// Define named type +struct ${operation_name} : + public ${operation_name}_base { }; + +""" + # + def instance_template(self): + return """ +${compile_guard_start} + { + using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>; + manifest.append( + new ${gemm_kind}("${operation_name}")); + } +${compile_guard_end} +""" + + + def emit_block_scale_epilogue_functor(self, operation): + block_scaled_template = """ + ${epilogue_functor}< + ${epi_vs}, + ${element_d}, + ${element_accumulator}, + ${element_sfd}, + ${layout_sfd}, + ${element_c}, + ${element_scalar} + > + """ + block_scaled_values = { + 'epi_vs' : str(operation.ScaleFactorVectorSize), + 'element_d': str(DataTypeTag[operation.D.element]), + 'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]), + 'layout_sfd': LayoutTag[operation.ScaleFactorD.layout], + 'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor], + 'element_accumulator': str(DataTypeTag[operation.accumulator_type()]), + 'element_scalar': str(DataTypeTag[operation.accumulator_type()]), + 'element_c': str(DataTypeTag[operation.C.element]), + } + return SubstituteTemplate(block_scaled_template, block_scaled_values) + + + @staticmethod + def pointerize_if_grouped(operation, layout): + return layout if not is_grouped(operation.gemm_kind) else layout + "* " + + @staticmethod + def transform_layout_A_if_blockwise(operation, layout): + layout_sfa = f"{operation.procedural_name()}_LayoutSFA" + layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* " + return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfa}>" + + @staticmethod + def transform_layout_B_if_blockwise(operation, layout): + layout_sfb = f"{operation.procedural_name()}_LayoutSFB" + layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* " + return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>" + + @staticmethod + def problem_shape(operation): + gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cute::Shape" + grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">" + + return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type + + def emit(self, operation): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") + _LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name()) + _LOGGER.debug("*** tile_shape: " + str(operation.tile_description.tile_shape)) + _LOGGER.debug("*** warp_count: " + str(operation.tile_description.warp_count)) + + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + + tile_shape = operation.tile_description.tile_shape + instruction_shape = operation.tile_description.math_instruction.instruction_shape + cluster_m = operation.tile_description.cluster_shape[0] + cluster_n = operation.tile_description.cluster_shape[1] + cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1] + tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape() + + # stage count set to zero indicates builder automatic stage selection + if operation.tile_description.stages > 0: + stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>" + elif opcode_class_main == OpcodeClass.SparseTensorOp and operation.arch == 100: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveoutEpi<{str(operation.procedural_name())}_epilogue>" + else: + stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" + + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + + instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ + (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout) + + # 3.0 profiler integration only supports trivial epilogues for now + epilogue_vector_length = 1 + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + values = { + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void: + epilogue_functor = self.emit_block_scale_epilogue_functor(operation) + + # + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. + element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] + + if opcode_class_main == OpcodeClass.BlockScaledTensorOp: + grouped = is_grouped(operation.gemm_kind) + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped): + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped): + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + # SM103 FP4 Ultra + is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped) + ] + is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), + to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped) + ] + if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule: + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)] + if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule: + epi_tile_mn = "cute::Shape" + if is_tma_epilogue(operation.epilogue_schedule): + epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)] + + element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>' + element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>' + + alignment_c = get_tma_alignment(operation.C.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.C.alignment + alignment_d = get_tma_alignment(operation.D.element) \ + if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \ + else operation.D.alignment + + operation_name_str = operation.procedural_name() + layout_a_str = LayoutTag[instance_layout_A] + layout_b_str = LayoutTag[instance_layout_B] + mixed_dtype_prepare_code = "" + if operation.mixed_input_mode != None: + A_dtype = operation.A.element + B_dtype = operation.B.element + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + narrow_tag = DataTypeTag[narrow_dtype] + wide_tag = DataTypeTag[wide_dtype] + scale_tag = DataTypeTag[wide_dtype] + zero_tag = DataTypeTag[wide_dtype] + + do_shuffle = False + value_shuffle_str = "" + if narrow_dtype_bits == 4 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + if narrow_dtype_bits == 8 and wide_dtype_bits == 16: + value_shuffle_str = "cute::Layout, cute::Stride>" + do_shuffle = True + do_shuffle = operation.mixed_input_shuffle and do_shuffle + + if do_shuffle: + if is_A_dtype_narrow: + stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>" + layout_a_str = f"{operation_name_str}_LayoutNarrowReordered" + else: + stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>" + layout_b_str = f"{operation_name_str}_LayoutNarrowReordered" + # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and + # layout_{a, b}_str are to prevent errors in Windows platform unity build + mixed_dtype_prepare_code = f""" +using {operation_name_str}_StrideNarrow = {stride_narrow_str}; +using {operation_name_str}_ValueShuffle = {value_shuffle_str}; +static constexpr int {operation_name_str}_NumShuffleAtoms = 1; +using {operation_name_str}_MmaAtomShape = cute::Layout>>; +using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>()); +using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout, {operation_name_str}_StrideNarrow>{{}})); + """ + + mixed_input_modes_to_element = { + MixedInputMode.ConvertOnly: narrow_tag, + MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>", + MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>" + } + narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag) + + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>" + + if is_A_dtype_narrow: + element_a = narrow_element + else: + element_b = narrow_element + + blockwise_prepare_code = "" + if is_blockwise(operation.gemm_kind): + sfm_vec_size = operation.ScaleFactorMVecSize + sfn_vec_size = operation.ScaleFactorNVecSize + sfk_vec_size = operation.ScaleFactorKVecSize + blockwise_prepare_code = f""" +using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>; +using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA()); +using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB()); + """ + + values = { + 'operation_name': operation_name_str, + 'operation_suffix': self.operation_suffix, + 'problem_shape': self.problem_shape(operation), + 'element_a': element_a, + 'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)), + 'element_b': element_b, + 'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]), + 'element_d': DataTypeTag[operation.D.element], + 'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]), + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class_main': OpcodeClassTag[opcode_class_main], + 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'tile_shape_m': str(tile_shape_m), + 'tile_shape_n': str(tile_shape_n), + 'tile_shape_k': str(tile_shape_k), + 'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int", + 'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int", + 'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int", + 'instruction_shape_m': str(instruction_shape[0]), + 'instruction_shape_n': str(instruction_shape[1]), + 'instruction_shape_k': str(instruction_shape[2]), + 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]), + 'epilogue_schedule' : str(epilogue_schedule_type), + 'epi_tile_mn' : epi_tile_mn, + 'epilogue_functor': epilogue_functor, + 'stages': stage_count_string, + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'align_c': str(alignment_c), + 'align_d': str(alignment_d), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]), + 'mixed_dtype_prepare_code': mixed_dtype_prepare_code, + 'blockwise_prepare_code' : blockwise_prepare_code + } + + return SubstituteTemplate(self.gemm_template, values) + +################################################################################################### + +# +class EmitGemmPlanarComplexInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmKernel; + + struct ${operation_name} : + public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# +class EmitGemmPlanarComplexArrayInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [] + self.template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal< + ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b}, + ${element_c}, cutlass::layout::RowMajor, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + cutlass::epilogue::thread::LinearCombinationPlanarComplex< + ${element_c}, + ${alignment_c}, + ${element_accumulator}, + ${element_epilogue} + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator} + >::GemmArrayKernel; + + struct ${operation_name} : public Operation_${operation_name} { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmUniversalAdapter<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major + transposed_layout_A = TransposedLayout[operation.A.layout] + transposed_layout_B = TransposedLayout[operation.B.layout] + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.B.element], + 'layout_a': LayoutTag[transposed_layout_B], + 'transform_a': ComplexTransformTag[operation.B.complex_transform], + 'alignment_a': str(operation.B.alignment), + 'element_b': DataTypeTag[operation.A.element], + 'layout_b': LayoutTag[transposed_layout_A], + 'transform_b': ComplexTransformTag[operation.A.complex_transform], + 'alignment_b': str(operation.A.alignment), + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'alignment_c': str(operation.C.alignment), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'stages': str(operation.tile_description.stages), + 'math_operator': 'cutlass::arch::OpMultiplyAdd' + } + + return SubstituteTemplate(self.template, values) + +################################################################################################### + +# +class EmitGemmGroupedInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self, operation_suffix = ''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/kernel/gemm_grouped.h", + "cutlass/gemm/kernel/default_gemm_grouped.h", + "cutlass/gemm/device/gemm_grouped.h" + ] + self.builtin_epilogue_functor_template = \ +"""${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >""" + + self.gemm_template = """ +// Gemm operator ${operation_name} +using ${operation_name}_base = + typename cutlass::gemm::kernel::DefaultGemmGrouped< + ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, + ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${scheduler_mode}, + ${math_operation} +>::GemmKernel; + +// Define named type +struct ${operation_name}${operation_suffix} : + public ${operation_name}_base { }; +""" + + # + def instance_template(self): + return """ +${compile_guard_start} + manifest.append(new ${gemm_kind}< + cutlass::gemm::device::GemmGrouped<${operation_name}> + >("${operation_name}")); +${compile_guard_end} +""" + + # + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor + } + + instance_layout_A, instance_layout_B, instance_layout_C = \ + (operation.A.layout, operation.B.layout, operation.C.layout) + # + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = \ + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + } + epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + # + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation] + } + + return SubstituteTemplate(self.gemm_template, values) + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitGemmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + GemmKind.Gemm: EmitGemmInstance, + GemmKind.Sparse: EmitSparseGemmInstance, + GemmKind.Universal: EmitGemmUniversalInstance, + GemmKind.Universal3x: EmitGemmUniversal3xInstance, + GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, + GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance, + GemmKind.Grouped: EmitGemmGroupedInstance, + GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance, + GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance, + } + + self.gemm_kind_wrappers = { + GemmKind.Gemm: 'GemmOperation', + GemmKind.Sparse: 'GemmSparseOperation', + GemmKind.Universal: 'GemmUniversalOperation', + GemmKind.Universal3x: 'GemmUniversal3xOperation', + GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation', + GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation', + GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', + GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation', + GemmKind.Grouped: 'GemmGroupedOperation', + GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation', + GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation', + GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation', + GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation', + } + + self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" + + self.separator = """ +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.header_template = """ +/* + Generated by gemm_operation.py - Do not edit. +*/ +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__") + _LOGGER.debug("*** configuration_path (file to write): " + + str(self.configuration_path)) + + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + self.configuration_file.write(self.separator) + + self.includes = collections.OrderedDict([ + ("cutlass/cutlass.h", None), + ("cutlass/library/library.h", None), + ("cutlass/library/manifest.h", None), + ("library_internal.h", None), + ("gemm_operation.h", None), + ("gemm_operation_3x.hpp", None), + ("grouped_gemm_operation_3x.hpp", None), + ("sparse_gemm_operation_3x.hpp", None), + ("block_scaled_gemm_operation_3x.hpp", None), + ("blockwise_gemm_operation_3x.hpp", None), + ("cutlass/arch/wmma.h", None), + ("cutlass/numeric_types.h", None) + ]) + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)") + _LOGGER.debug("*** operation.gemm_kind: " + str(operation.gemm_kind)) + + emitter = self.instance_emitter[operation.gemm_kind]() + + for incl in emitter.includes: + self.includes[incl] = None + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write includes + for incl, _ in self.includes.items(): + include_statement = "#include \"%s\"\n" % incl + self.configuration_file.write(include_statement) + + self.configuration_file.write(self.separator) + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..063e8fb1caa6626e8ba099133fee4dd3dc115e40 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py @@ -0,0 +1,10962 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library kernels +""" + +import argparse +import enum +from itertools import chain, product +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Dict, Optional, Sequence, Tuple + +_LOGGER = logging.getLogger(__name__) + +def logging_prefix(indent_level: int = 0) -> str: + """String prefix for start of each debug log entry""" + prefix = '*** ' + indent = ' ' + return f"{prefix}{indent_level * indent}" + +def log_debug_line(line: str, indent_level: int = 0) -> None: + """Log one line of debug output""" + prefix = logging_prefix(indent_level) + _LOGGER.debug(prefix + line) + +# Certain usecases of cutlass_library nearly always prefer to run as scripts with +# relative imports, rather than via an installed Python package. An example of this +# is using CUTLASS's CMake system to generate a library of kernels to be profiled. +# To make it easy to use these use cases when an existing installation of cutlass_library +# exists, this global flag can be set to true (via command-line arguments) to ensure +# that package-based installations are not used. + +# Create a temporary argument parser to check only for the availability of the +# --disable-cutlass-package-imports argument, which controls whether package-based +# imports are disabled. +def _add_package_disablement_flag(argparser): + argparser.add_argument("--disable-cutlass-package-imports", action='store_true', required=False, + help="Disable use of cutlass_library from Python package") + +_parser = argparse.ArgumentParser() +_add_package_disablement_flag(_parser) +_args, _ = _parser.parse_known_args() + +# Add `CUTLASS_IGNORE_PACKAGE` to `builtins` so that it is visible for gating future +# imports without requiring importing another module. Ideally, we would just place this +# as a global variable in a module to that could be imported and checked (e.g., +# utils.CUTLASS_IGNORE_PACKAGE). However, this raises the issue of determining +# where this module should be sourced (from the cutlass_library package or from +# a relative import), which is the problem this variable is being used to solve in the +# first place. +import builtins +builtins.CUTLASS_IGNORE_PACKAGE = _args.disable_cutlass_package_imports + +try: + if CUTLASS_IGNORE_PACKAGE: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.manifest import * + from cutlass_library.heuristics import * + from cutlass_library.emit_kernel_listing import emit_gemm_kernel_testlist +except ImportError: + from library import * + from manifest import * + from heuristics import * + from emit_kernel_listing import emit_gemm_kernel_testlist +################################################################################################### + +# +def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): + + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] + + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + +# From cuda 13.0, Thor SM is renumbered from 101 to 110 +def ThorSMRenumbering(cuda_version): + return 110 if CudaToolkitVersionSatisfies(cuda_version, 13, 0) else 101 + +################################################################################################### +################################################################################################### + +# +def EpilogueAlignment(max_alignment, tile, epilogue_steps = 8): + ''' Helper to compute the maximum alignment of the epilogue ''' + + def product(X, identity = 1): + result = identity + for item in X: + result *= item + return result + + elements_per_thread = product(tile.threadblock_shape[:-1]) // product(tile.warp_count) // 32 // epilogue_steps + return min(max_alignment, elements_per_thread) + +def DefaultSwizzlingFunctor(): + return SwizzlingFunctor.Identity8 + # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` + +# +def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = DefaultSwizzlingFunctor()): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + # If alignment is a tuple or a list, then we have different alignments for A and B + alignment_a = alignment if isinstance(alignment, int) else alignment[0] + alignment_b = alignment if isinstance(alignment, int) else alignment[1] + alignment_c = min(8, alignment_a) if isinstance(alignment, int) else alignment[2] + + A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts +def CreateGemmUniversal3xOperator( + manifest, layouts, tile_descriptions, data_types, + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity1, + tile_schedulers=[TileSchedulerType.Default], + gemm_kind=GemmKind.Universal3x): + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedules: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + if len(tile_descriptions) == 0: + return operations + tile_descriptions = [tile_descriptions[0]] + + combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) + for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: + kernel_schedule, epilogue_schedule = schedules + A = TensorDescription( + data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) + + C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) + D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) + + gemm_op_extra_args = {} + element_compute = data_type.get("epi_type", data_type["acc_type"]) + + if "sf_type" in data_type: + gemm_op_extra_args["ScaleFactorA"] = data_type["sf_type"] + gemm_op_extra_args["ScaleFactorB"] = data_type["sf_type"] + gemm_op_extra_args["ScaleFactorD"] = { "tensor": TensorDescription(data_type["sfd_type"]["type"], data_type["sfd_type"]["layout"]), + "vector_size" : data_type["sfd_type"]["vector_size"]} + assert is_block_scaled(gemm_kind) + + if tile_description.explicit_vector_sizes != None: + assert len(tile_description.explicit_vector_sizes) == 3 + gemm_op_extra_args["ScaleFactorMVecSize"] = tile_description.explicit_vector_sizes[0] + gemm_op_extra_args["ScaleFactorNVecSize"] = tile_description.explicit_vector_sizes[1] + gemm_op_extra_args["ScaleFactorKVecSize"] = tile_description.explicit_vector_sizes[2] + assert is_blockwise(gemm_kind) + else: + assert not is_blockwise(gemm_kind) + + A_dtype = data_type["a_type"] + B_dtype = data_type["b_type"] + A_dtype_bits = DataTypeSize[A_dtype] + B_dtype_bits = DataTypeSize[B_dtype] + is_A_dtype_narrow = A_dtype_bits < B_dtype_bits + if is_A_dtype_narrow: + narrow_dtype, wide_dtype = (A_dtype, B_dtype) + narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits) + else: + narrow_dtype, wide_dtype = (B_dtype, A_dtype) + narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits) + + mixed_input_modes = [None] + if narrow_dtype_bits != wide_dtype_bits: + if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2): + mixed_input_modes = [MixedInputMode.ScaleOnly] + else: + mixed_input_modes = [MixedInputMode.ConvertOnly, MixedInputMode.ScaleOnly, MixedInputMode.ScaleWithZeroPoint] + + mixed_input_shuffle_options = [False] + if (mixed_input_modes[0] is not None) and (wide_dtype_bits == 16) and (narrow_dtype_bits == 4 or narrow_dtype_bits == 8): + mixed_input_shuffle_options = [False, True] + + for mixed_input_mode, mixed_input_shuffle in product(mixed_input_modes, mixed_input_shuffle_options): + operation = GemmOperation( + gemm_kind, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule, tile_scheduler, + mixed_input_mode=mixed_input_mode, mixed_input_shuffle=mixed_input_shuffle, **gemm_op_extra_args) + manifest.append(operation) + operations.append(operation) + + return operations + +# Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts +def CreateSparseGemmUniversal3xOperator( + manifest, layouts, tile_descriptions, data_types, + schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity1, + tile_schedulers=[TileSchedulerType.Default]): + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedules: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none), ] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0]] + + combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) + for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: + kernel_schedule, epilogue_schedule = schedules + A = TensorDescription( + data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) + + # Currently assume tensor C/D have same layout requirement. + C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) + D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) + + element_compute = data_type.get("epi_type", data_type["acc_type"]) + + operation = GemmOperation( + GemmKind.SparseUniversal3x, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule, tile_scheduler) + + manifest.append(operation) + operations.append(operation) + + return operations + +# +def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.Sparse] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Sparse, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.PlanarComplex, GemmKind.PlanarComplexArray] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for gemm_kind in gemm_kinds: + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + manifest.append(GemmOperation(gemm_kind, \ + tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue)) + return + +# +def CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GroupedGemmOperation(GemmKind.Grouped, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, data_type, \ + alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + element_a, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for fill_mode in fill_modes: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + + # SERK supported layouts (RowMajor, ColumnMajor) with no conjugation + complex_transform = ComplexTransform.none + + # HERK supported layouts (RowMajor + conj, ColumnMajor) + if blas_mode == BlasMode.hermitian and layout[0] == LayoutType.RowMajor: + complex_transform = ComplexTransform.conj + + alignment_c = 1 # Alignment only applies to A in SYRK + + A = TensorDescription(element_a, layout[0], alignment, complex_transform) + C = SymmetricTensorDescription(element_c, layout[1], fill_mode, alignment_c) + + # Rank-K update + new_operation = RankKOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + # Rank-2K update + new_operation = Rank2KOperation(RankKKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for side_mode in side_modes: + for fill_mode in fill_modes: + for diag_type in diag_types: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TriangularTensorDescription(element_a, layout[0], side_mode, fill_mode, diag_type, + alignment, complex_transform) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = TrmmOperation(TrmmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# +def CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, data_type, \ + alignment_constraints, blas_mode, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for side_mode in side_modes: + for fill_mode in fill_modes: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + + # SYMM supported layouts (RowMajor, ColumnMajor) with no conjugation + complex_transform = ComplexTransform.none + + alignment_a = 1 # No vectorized access for the triangular matrix + alignment_c = min(8, alignment) + + A = SymmetricTensorDescription(element_a, layout[0], fill_mode, alignment_a, complex_transform, side_mode) + # tensor A and B have same data type and layout + B = TensorDescription(element_b, layout[0], alignment) + C = TensorDescription(element_c, layout[1], alignment_c) + + # SYMM/HEMM update + new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + # SYMM/HEMM update + new_operation = SymmOperation(SymmKind.Universal, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor, blas_mode) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +########################################################################################################### +# ConvolutionOperator support variations +# ____________________________________________________________________ +# ConvolutionalOperator | Analytic | Optimized +# ____________________________________________________________________ +# | Fprop | (strided) | (strided) +# | Dgrad | (strided, unity*) | (strided, unity) +# | Wgrad | (strided) | (strided) +# ____________________________________________________________________ +# +# Note : Operator marked (*) are supported but not generated to keep the instantiated kernel count low +########################################################################################################### +# Convolution for 2D operations +def CreateConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + iterator_algorithms = [IteratorAlgorithm.Optimized] + + operations = [] + + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operations = [ + # None grouped kernel + Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_), + ] + + # Instance group conv kernel + if tile.math_instruction.opcode_class == OpcodeClass.TensorOp and A.layout == LayoutType.TensorNHWC and \ + tile.minimum_compute_capability >= 80: + # SingleGroup kernel + new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.SingleGroup)) + + # Analytic iterator supports MultipleGroup mode + if iterator_algorithm == IteratorAlgorithm.Analytic: + new_operations.append(Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_, group_mode=GroupMode.MultipleGroup)) + + for new_operation in new_operations: + manifest.append(new_operation) + operations.append(new_operation) + + # + # Conv2d Dgrad + # + if ConvKind.Dgrad in conv_kinds: + + # Unity stride for Analytic and Optimized Dgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # strided dgrad uses a special threadblock swizzle + # note that SwizzlingFunctor.StridedDgradHorizontal might be + # better for problem sizes with large activation channel count + swizzling_functor_strided_dgrad_ = SwizzlingFunctor.StridedDgradIdentity1 + + if IteratorAlgorithm.Analytic in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Optimized Dgrad + if IteratorAlgorithm.Optimized in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_strided_dgrad_) + + manifest.append(new_operation) + operations.append(new_operation) + + # + # Conv2d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 2D operations specialized for few channels +def CreateConv2dFixedChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.FixedChannels,] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + channel_counts = [channel_counts[0],] + + operations = [] + + + + for tile in tile_descriptions: + for channel_count in channel_counts: + + alignment_c = EpilogueAlignment(channel_count, tile) + + A = TensorDescription(element_a, layout[0], channel_count) + B = TensorDescription(element_b, layout[1], channel_count) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 2D operations specialized for few channels +def CreateConv2dFewChannelsOperator(manifest, layout, tile_descriptions, data_type, channel_counts, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.FewChannels,] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + channel_counts = [channel_counts[0],] + + operations = [] + + for tile in tile_descriptions: + for channel_count in channel_counts: + + alignment_c = EpilogueAlignment(channel_count, tile) + + A = TensorDescription(element_a, layout[0], channel_count) + B = TensorDescription(element_b, layout[1], channel_count) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + # + # Conv2d Fprop + # + if ConvKind.Fprop in conv_kinds: + + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv2dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor, swizzling_functor_) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for 3D operations +def CreateConv3dOperator(manifest, layout, tile_descriptions, data_type, alignment, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], epilogue_functor = EpilogueFunctor.LinearCombination): + + element_a, element_b, element_c, element_epilogue = data_type + + # one exceptional case + alignment_c = min(8, alignment) + + # iterator algorithm (analytic and optimized) + iterator_algorithms = [IteratorAlgorithm.Analytic, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size and optimized iterators + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + iterator_algorithms = [IteratorAlgorithm.Optimized] + + operations = [] + + # All tile sizes for Conv3dFprop and Conv3dWgrad + for tile in tile_descriptions: + A = TensorDescription(element_a, layout, alignment) + B = TensorDescription(element_b, layout, alignment) + C = TensorDescription(element_c, layout, alignment_c) + + # + # Conv3d Fprop + # + if ConvKind.Fprop in conv_kinds: + # Strided support for Analytic and Optimized Fprop + for iterator_algorithm in iterator_algorithms: + new_operation = Conv3dOperation(ConvKind.Fprop, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided) + manifest.append(new_operation) + operations.append(new_operation) + # + # Conv3d Wgrad + # + if ConvKind.Wgrad in conv_kinds: + + # Strided support for Analytic and Optimized Wgrad + for iterator_algorithm in iterator_algorithms: + new_operation = Conv3dOperation(ConvKind.Wgrad, iterator_algorithm, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + manifest.append(new_operation) + operations.append(new_operation) + + # All tile sizes for Conv3dDgrad + for tile in tile_descriptions: + + A = TensorDescription(element_a, layout, alignment) + B = TensorDescription(element_b, layout, alignment) + C = TensorDescription(element_c, layout, alignment_c) + + # + # Conv3d Dgrad + # + if ConvKind.Dgrad in conv_kinds: + # Unity stride for Optimized Dgrad + new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Optimized, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Unity, epilogue_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + # Strided support for Analytic Dgrad + # Conv3dDgrad has a naive strided support which does not cut down redundant MMAs + new_operation = Conv3dOperation(ConvKind.Dgrad, IteratorAlgorithm.Analytic, tile.minimum_compute_capability, tile,\ + A, B, C, element_epilogue, StrideSupport.Strided, epilogue_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +# Convolution for Depthwise 2d conv +def CreateDepthwiseConv2dOperator(manifest, layout, tile_descriptions, data_type, alignment_constraints, \ + conv_kinds = [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad], \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4): + + element_a, element_b, element_c, element_epilogue = data_type + + # iterator algorithm (FixedStrideDilation, Optimized) + iterator_algorithms = [IteratorAlgorithm.FixedStrideDilation, IteratorAlgorithm.Optimized] + + # by default, only generate the largest tile size, largest alignment, and optimized iterator + if manifest.kernel_filter == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + operations = [] + + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + if ConvKind.Fprop in conv_kinds: + + # Strided support for Optimized and FixedStridedDilation Depthwise Conv + for iterator_algorithm in iterator_algorithms: + stride_support = StrideSupport.Strided + if iterator_algorithm == IteratorAlgorithm.FixedStrideDilation: + if tile.stride == [-1, -1] or tile.dilation == [-1,-1]: + continue + stride_support = StrideSupport.Fixed + + if iterator_algorithm == IteratorAlgorithm.Optimized: + if tile.stride != [-1, -1] or tile.dilation != [-1,-1]: + continue + new_operation = Conv2dOperation(ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, B, C, + element_epilogue, + stride_support, + epilogue_functor, + swizzling_functor_, + group_mode=GroupMode.Depthwise) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + +class ConvOperation3x: + """All parameters of a CUTLASS 3 convolution operation. + + Unlike CUTLASS 2 convolutions, CUTLASS 3 convolutions do not + distinguish between 2-D and 3-D convolutions by kernel class name. + Instead, for CUTLASS 3 convolutions, the tensor layouts encode + whether the convolution is 2-D or 3-D. Thus, this class deduces + the OperationKind (either Conv2d or Conv3d) from the layouts, + rather than taking it as a constructor parameter. + """ + def __init__(self, + conv_kind: ConvKind, + tile_description: TileDescription, + A: TensorDescription, + B: TensorDescription, + C: TensorDescription, + element_compute: Optional[DataType] = None, + D: Optional[TensorDescription] = None, + kernel_schedule: KernelScheduleType = KernelScheduleType.ScheduleAuto, + epilogue_schedule: EpilogueScheduleType = EpilogueScheduleType.ScheduleAuto, + tile_scheduler: TileSchedulerType = TileSchedulerType.Default, + log_indent_level: int = 1): + log_debug_line(f'ConvOperation3x::init: conv_kind: {conv_kind}', log_indent_level) + log_indent_level = log_indent_level + 1 + + self.conv_kind = conv_kind + self.tile_description = tile_description + self.A = A + self.B = B + self.C = C + self.element_compute = C.element if element_compute is None else element_compute + self.kernel_schedule = kernel_schedule + self.epilogue_schedule = epilogue_schedule + + self.arch = tile_description.minimum_compute_capability + self.tile_scheduler = tile_scheduler + if D == None: + self.D = C + else: + self.D = D + + self.is_3x = True + self.group_mode = GroupMode.NoneGroup # CUTLASS 3 convolutions currently aren't grouped + + operation_kind = None + for layout in (A.layout, B.layout, C.layout): + assert(isinstance(layout, LayoutType)) + new_operation_kind = convolution_tensor_layout_type_to_operation_kind(layout) + if operation_kind is None: + operation_kind = new_operation_kind + else: # CUTLASS 3 convolutions don't permit mixing 2-D and 3-D layouts. + assert(operation_kind == new_operation_kind) + assert(operation_kind is not None) + self.operation_kind = operation_kind + + def __str__(self): + return f"ConvOperation3x: operation_kind={self.operation_kind}, conv_kind={self.conv_kind}, tile_description={self.tile_description}" + + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + + def is_mixed_input(self): + return self.A.element != self.B.element + + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + if self.is_complex(): + return get_complex_from_real(accum) + return accum + + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and', + } + + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + OpcodeClass.BlockScaledTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s" % (math_op_string, intermediate_type, ConvKindNames[self.conv_kind]) + + def extended_name(self): + '''Generates a string representing the MMA atom. Assumes accumulator type is C type.''' + extended_name = "{core_name}_{element_a}{layout_a}_{element_b}{layout_b}_{element_acc}_{element_c}_{element_d}{layout_c}".format( + element_a = DataTypeNames[self.A.element], + layout_a = ShortLayoutTypeNames[self.A.layout], + element_b = DataTypeNames[self.B.element], + layout_b = ShortLayoutTypeNames[self.B.layout], + element_acc = DataTypeNames[self.accumulator_type()], + element_c = DataTypeNames[self.C.element], + layout_c = ShortLayoutTypeNames[self.C.layout], + element_d = DataTypeNames[self.D.element], + core_name = self.core_name()) + + return extended_name + + # Generates a short string representing underlying kernel schedule type + def kernel_schedule_name(self): + return KernelScheduleSuffixes[self.kernel_schedule] + + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name(self): + return EpilogueScheduleSuffixes[self.epilogue_schedule] + + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + # Generates the full kernel function name + def configuration_name(self): + ''' The full function name indicates architecture, extended name, tile size, and layout. ''' + kernel_name_template = "cutlass3x_sm{ar}_{op}_{ex}{ct}{cs}_{l}_align{al}{t}{k}{e}" + return kernel_name_template.format( + ar = self.arch, + op = self.opcode_class_name(), + ex = self.extended_name(), + ct = '_' + 'x'.join([str(i) for i in self.tile_description.tile_shape]) if self.tile_description.tile_shape[0] > 0 else "", + cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]), + l = self.tile_description.stages, + al = str(max(self.A.alignment, self.B.alignment)), + t = TileSchedulerSuffixes[self.tile_scheduler], + k = self.kernel_schedule_name(), + e = self.epilogue_schedule_name()) + + def procedural_name(self): + return self.configuration_name() + +def convolution_tensor_layout_type_to_operation_kind(layout: LayoutType) -> OperationKind: + if layout == LayoutType.TensorNHWC or layout == LayoutType.TensorKCSR: + return OperationKind.Conv2d + elif layout == LayoutType.TensorNDHWC or layout == LayoutType.TensorKCSRT: + return OperationKind.Conv3d + else: + raise RuntimeError(f'LayoutType {layout} does not have a corresponding OperationKind') + +def CreateConvOperator3x(manifest: Manifest, + dims_and_alignments: Sequence[Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]], + tile_descriptions: Sequence[Sequence[TileDescription]], + data_types, + schedule_pairs: Sequence[Tuple[KernelScheduleType, KernelScheduleType]] = \ + [(KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto)], + complex_transforms: Optional[Sequence[ComplexTransform]] = None, + tile_schedulers: Sequence[TileSchedulerType] = [TileSchedulerType.Default], + conv_kind: ConvKind = ConvKind.Fprop, + log_indent_level: int = 1): + """ + Create zero or more CUTLASS 3 two-dimensional convolution operators. + + Create a CUTLASS 3 two-dimensional convolution operator + for all feasible combinations of the input parameters. + Add the operators to the manifest. + + dims_and_alignments: 3-level list. Each outer list term is a list [A, B, C]. + Each inner list (A, B, or C) has the form [num_spatial_dimensions, alignment]. + Both are integers; the first is the number of spatial dimensions + (currently, only 2 or 3 are supported), and the second is the byte alignment. + We deduce the operation_kind (either OperationKind.Conv2d or OperationKind.Conv3d) + from num_spatial_dimensions. + + This function doesn't take layouts, unlike the GEMM functions. + CUTLASS 3 convolutions currently support three input layouts: + + * TensorNWC for 1-D convolutions, + * TensorNHWC for 2-D convolutions, and + * TensorNDHWC for 3-D convolutions. + + Output (C and D) layouts are the same as input layouts, + except for Wgrad convolutions, where the layouts are + + * TensorKCS for 1-D convolutions, + * TensorKCSR for 2-D convolutions, and + * TensorKCSRT for 3-D convolutions. + + The output layouts are completely constrained by the input layouts + and the convolution kind. + + tile_descriptions: 2-level list. + Outer level has one list per math instruction. + Inner level has one TileDescription for each cluster shape. + + data_types: Either a single data_type dictionary, or a list of them. + Keys: 'a_type', 'b_type', 'c_type', 'd_type', 'acc_type', 'epi_type' + + complex_transforms: Optional list of pairs. + First element of each pair is the complex transform for A, and + second element of each pair is the complex transform for B. + + schedule_pairs: [(kernel_schedule, epilogue_schedule), ...] + + conv_kind: Convolution kind (Fprop, Dgrad, or Wgrad). + """ + log_debug_line('CreateConvOperator3x', log_indent_level) + log_indent_level = log_indent_level + 1 + log_debug_line(f'conv_kind: {conv_kind}', log_indent_level) + + for triple in dims_and_alignments: + assert(isinstance(triple, tuple) or isinstance(triple, list)) + assert(len(triple) == 3) + + spatial_dimensionality = None # to be determined by loop below + + for entry in triple: # [A, B, C] + assert(len(entry) == 2) + [dim, alignment] = entry + assert(type(dim) is int) + assert(dim == 2 or dim == 3) + assert(type(alignment) is int) + assert(alignment > 0) + if spatial_dimensionality is None: + spatial_dimensionality = dim + else: + # A, B, and C need to have the same spatial dimensionality + assert(spatial_dimensionality == dim) + + def input_and_output_layouts(spatial_dim: int, kind: ConvKind) -> Tuple[LayoutType, LayoutType]: + if spatial_dim == 1: + input_layout = LayoutType.TensorNWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCS + else: + output_layout = input_layout + elif spatial_dim == 2: + input_layout = LayoutType.TensorNHWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCSR + else: + output_layout = input_layout + elif spatial_dim == 3: + input_layout = LayoutType.TensorNDHWC + if kind == ConvKind.Wgrad: + output_layout = LayoutType.TensorKCSRT + else: + output_layout = input_layout + else: + assert(False) + return (input_layout, output_layout) + + def dims_to_layouts(A_B_C: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]]) -> \ + Tuple[Tuple[LayoutType, int], Tuple[LayoutType, int], Tuple[LayoutType, int]]: + [A, B, C] = A_B_C + [spatial_dim, alignment] = A + [input_layout, output_layout] = input_and_output_layouts(spatial_dim, conv_kind) + return ((input_layout, A[1]), + (input_layout, B[1]), + (output_layout, C[1])) + + # layouts: list of triples (A, B, C). + # Each of A, B, and C has the form [layout, alignment]. + layouts = [dims_to_layouts(A_B_C) for A_B_C in dims_and_alignments] + + if type(data_types) is dict: + data_types = [data_types] + + for s in schedule_pairs: + assert(len(s) == 2) + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none)] + + # product produces a one-pass generator, so the loop must call it anew each time. + def make_combinations(): + return product( + layouts, + tile_descriptions, + data_types, + complex_transforms, + schedule_pairs, + tile_schedulers + ) + + operations = [] + for layout_triple, tile_description, data_type, complex_transform_pair, schedule_pair, tile_scheduler in make_combinations(): + A_layout, A_alignment = layout_triple[0] + A_xform = complex_transform_pair[0] + B_layout, B_alignment = layout_triple[1] + B_xform = complex_transform_pair[1] + C_layout, C_alignment = layout_triple[2] + D_layout = C_layout + D_alignment = C_alignment + + A = TensorDescription(data_type["a_type"], A_layout, A_alignment, A_xform) + B = TensorDescription(data_type["b_type"], B_layout, B_alignment, B_xform) + C = TensorDescription(data_type["c_type"], C_layout, C_alignment) + D = TensorDescription(data_type["d_type"], D_layout, D_alignment) + element_compute = data_type.get("epi_type", data_type["acc_type"]) + kernel_schedule, epilogue_schedule = schedule_pair + + operation = ConvOperation3x(conv_kind=conv_kind, + tile_description=tile_description, + A=A, + B=B, + C=C, + element_compute=element_compute, + D=D, + kernel_schedule=kernel_schedule, + epilogue_schedule=epilogue_schedule, + tile_scheduler=tile_scheduler, + log_indent_level=log_indent_level) + log_debug_line(f'Created ConvOperation3x: {str(operation)}', log_indent_level) + manifest.append(operation) + operations.append(operation) + + return operations + +################################################################################################### +################################################################################################### + +# +def GenerateSM50_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 50 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + if math_inst.element_a == DataType.f32: + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM50_Simt_complex(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 50 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32, + ] + + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM50(manifest, cuda_version): + GenerateSM50_Simt(manifest, cuda_version) + GenerateSM50_Simt_complex(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM60_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# +def GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version): + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 60 + max_cc = 1024 + + alignment_constraints = [8,] + + filter_3x3 = [3, 3] + filter_5x5 = [5, 5] + + # [stride_h, stride_w] + # [-1, -1] means all stride size. + strides = [[-1,-1], [1, 1], [2, 2]] + # [dilation_h, dilation_w] + # [-1, -1] means all dilation size. + dilations = [[-1,-1], [1, 1], [2, 2]] + + #groups per thread block + g16 = 16 + g32 = 32 + g64 = 64 + + #output shape per thread block + npq_1x4x4 = [1, 4, 4] + npq_1x8x8 = [1, 8, 8] + npq_1x10x10 = [1, 10, 10] + + tile_descriptions = [] + for math_inst in math_instructions: + for stride, dilation in product(strides, dilations): + tile_descriptions.extend([ + # filter3x3 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_3x3, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_3x3, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_3x3, 4, stride, dilation,[4, 1, 1], math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_3x3, 4, stride, dilation, [4, 1, 1], math_inst, min_cc, max_cc), + + # filter5x5 ThreadBlock_output, filter, stage, warp + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g32], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g64], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x8x8+[g16], filter_5x5, 3, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x10x10+[g64], filter_5x5, 2, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g32], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g64], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc), + Direct2dConvFixedStrideDilationTileDescription(npq_1x4x4+[g16], filter_5x5, 4, stride, dilation,[4, 1, 1],math_inst, min_cc, max_cc) + ]) + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateDepthwiseConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM60(manifest, cuda_version): + GenerateSM60_Simt(manifest, cuda_version) + GenerateSM60_Simt_DepthwiseConv2d(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM61_Simt(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 4], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 61 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) +# + +# +def GenerateSM61(manifest, cuda_version): + GenerateSM61_Simt(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM70_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + +# +def GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [8, 8, 4], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + + +# +def GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version): + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 16, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 70 + max_cc = 1024 + + alignment_constraints = [8,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# +################################################################################################## +# + +def GenerateSM70(manifest, cuda_version): + GenerateSM70_TensorOp_884(manifest, cuda_version) + GenerateSM70_PlanarComplexTensorOp_884(manifest, cuda_version) + + # To limit build size, WMMA GEMMs are disabled for now. + # + #GenerateSM70_WmmaTensorOp_161616(manifest, cuda_version) + +################################################################################################### +################################################################################################### + +# +def GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst): + + min_cc = 75 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 64, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 2], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) + CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [1, 2, 4]) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) + CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [1, 2, 4]) + +# +def GenerateSM75_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + + # Separate generator for 'few channels' specializations + GenerateSM75_TensorOp_1688_FewChannels(manifest, cuda_version, math_inst) + +# + +# +def GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [8, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM75_TensorOp_8816_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 90 + + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), + + TileDescription([256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 16], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 90 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 +# + +# +def GenerateSM75_TensorOp_8832_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 89 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.s32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 8 + +# + +# +def GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 2): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 32], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [8, 8, 32], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 75 + max_cc = 89 + + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 2, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM75_TensorOp_88128(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 128], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + ] + + min_cc = 75 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 16, 16], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.WmmaTensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) +# + +# +def GenerateSM75_Simt_complex(manifest, cuda_version): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc) + ] + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +def GenerateSM75(manifest, cuda_version): + GenerateSM75_TensorOp_1688(manifest, cuda_version) + GenerateSM75_PlanarComplexTensorOp_1688(manifest, cuda_version) + GenerateSM75_TensorOp_8816_TN(manifest, cuda_version) + GenerateSM75_TensorOp_8816_Interleaved(manifest, cuda_version) + GenerateSM75_TensorOp_8832_TN(manifest, cuda_version) + GenerateSM75_TensorOp_8832_Interleaved(manifest, cuda_version) + GenerateSM75_TensorOp_88128(manifest, cuda_version) + #GenerateSM75_WmmaTensorOp_161616(manifest, cuda_version) + GenerateSM75_Simt_complex(manifest, cuda_version) + + +################################################################################################### +################################################################################################### + +# +def GenerateSM80_TensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, 4, 2] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmGroupedOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type, [4, 8]) + CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type, 8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) + CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, [4, 8]) + CreateConv3dOperator(manifest, LayoutType.TensorNDHWC, tile_descriptions, data_type_mixed, 8) +# + +# +def GenerateSM80_SparseTensorOp_16832(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + +# +def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, ] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([ 64, 128, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.s8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.u8, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[16, 8, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + # 128x128 + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_b != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + math_inst.element_accumulator, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + for op in operations: + if (DataTypeSize[op.C.element] == 16) and \ + (op.tile_description.threadblock_shape[1] <= 32): + op.C.alignment = 4 + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.s8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.bf16, DataType.u8, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.s8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + MathInstruction( \ + [16, 8, 16], \ + DataType.f16, DataType.u8, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[8, 16, 8],] + + for math_inst in math_instructions: + tile_descriptions = [ + # 128x128 + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 9, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + # 256x16 + TileDescription([256, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] <= 32: + op.C.alignment = 4 + +# +def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + smem_usage = 164 + + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + operations += CreateConv2dFewChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +# + +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand A + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s4, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[32, 16, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[32, 16, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + DataType.f32 + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + # Upcast on Operand B + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_mixed_input_upcast), + ] + + min_cc = 80 + max_cc = 1024 + + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[16, 32, 4],] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. S8 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + alignment_constraints = [[16, 32, 16],] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + DataType.f32, + ] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp, SwizzlingFunctor.Identity8) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 64], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + tile_descriptions = [ + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s8, DataType.s8, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s8, DataType.s8, DataType.s8, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved32, LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 32], \ + DataType.u8, DataType.u8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC32HW32, LayoutType.TensorC32RSK32, LayoutType.TensorNC32HW32) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16864_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 128], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + tile_descriptions = [ + TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s4, DataType.s4, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s4, DataType.s4, DataType.s4, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] > 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + +# +def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajorInterleaved64, LayoutType.RowMajorInterleaved64, LayoutType.ColumnMajorInterleaved64), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 64], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + MathInstruction( \ + [16, 8, 64], \ + DataType.u4, DataType.u4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate), + ] + + min_cc = 80 + max_cc = 1024 + alignment_constraints = [32,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] + + operations = [] + + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + conv_layout = (LayoutType.TensorNC64HW64, LayoutType.TensorC64RSK64, LayoutType.TensorNC64HW64) + + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type_mixed, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + op.C.alignment = 16 +# + +# +def GenerateSM80_TensorOp_168256(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.and_popc), + ] + + min_cc = 80 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 10, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 1024], 4, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 1024], 4, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + +# + +# +def GenerateSM80_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type_mixed, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f16), + MathInstruction( \ + [16, 8, 8], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_bf16), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +def GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + + +# +def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + tile_descriptions = [ + TileDescription([128, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1, 2, 4] # Alignment only applies to A in SYRK + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + # SYRK + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1, 2, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_1688_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + # A and B have same layouts + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [ + 1, 2, 4 + ] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + #TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 8], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex), + MathInstruction( \ + [16, 8, 8], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_fast_f32), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 16], 4, [2, 1, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + DataType.cf32, DataType.cf32, DataType.cf32, DataType.cf32 + ] + + alignment_constraints = [1,] + + # SYMM + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) + +# +def GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM80_TensorOp_884_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [8, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +################################################################################################### + +# +def GenerateSM80_Simt_f32(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 5, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 8], 4, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + + +# +def GenerateSM80_Simt_f64(manifest, cuda_version): + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 8], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 8], 5, [1, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + + +################################################################################################## +# +def GenerateSM80_Simt_complex(manifest, cuda_version): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + for math_inst in math_instructions: + + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, alignment_constraints, complex_transforms) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + CreateConv2dOperator(manifest, conv_layout, tile_descriptions, data_type, alignment_constraints) +# + +################################################################################################### + +# +def GenerateSM80(manifest, cuda_version): + GenerateSM80_TensorOp_16816(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16832(manifest, cuda_version) + GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version) + GenerateSM80_TensorOp_1688(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16816_fast_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_complex(manifest, cuda_version) + # 3xTF32 + GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_fp32_math_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_rank_k(manifest, cuda_version) + GenerateSM80_TensorOp_1688_rank_k_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_trmm(manifest, cuda_version) + GenerateSM80_TensorOp_1688_trmm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_1688_symm(manifest, cuda_version) + GenerateSM80_TensorOp_1688_symm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884(manifest, cuda_version) + GenerateSM80_TensorOp_884_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_rank_k_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_trmm_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version) + GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16832_TN_mixed_input_upcast_b(manifest, cuda_version) + GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) + GenerateSM80_TensorOp_16864_TN(manifest, cuda_version) + GenerateSM80_SparseTensorOp_168128_TN(manifest, cuda_version) + GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version) + GenerateSM80_TensorOp_168256(manifest, cuda_version) + GenerateSM80_Simt_f32(manifest, cuda_version) + GenerateSM80_Simt_f64(manifest, cuda_version) + GenerateSM80_Simt_complex(manifest, cuda_version) + +################################################################################################### + +def GenerateSM89_TensorOp_16832_fp8(manifest, element_acc): + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) + ] + + math_instructions = [ + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e4m3, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e4m3, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 32], + DataType.e5m2, DataType.e5m2, element_acc, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + ] + + min_cc = 89 + max_cc = 100 + alignment_constraints = [16,] + alignment_constraints_small_channels = [16, 8, 4] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 6, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 6, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_types = [ + [ + math_inst.element_a, + math_inst.element_b, + DataType.f32, + math_inst.element_accumulator + ], + [ + math_inst.element_a, + math_inst.element_b, + DataType.bf16, + math_inst.element_accumulator + ], + ] + + operations = [] + for data_type in data_types: + operations += CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, + alignment_constraints, None, EpilogueFunctor.LinearCombination) + + conv_layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + operations += CreateConv2dOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + operations += CreateConv2dFixedChannelsOperator(manifest, conv_layout, tile_descriptions, + data_type, alignment_constraints_small_channels, [ConvKind.Fprop], EpilogueFunctor.LinearCombination) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +def GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f32) + +def GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + GenerateSM89_TensorOp_16832_fp8(manifest, DataType.f16) + +# +def GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version): + + if ( + not CudaToolkitVersionSatisfies(cuda_version, 12, 4) + ): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor) + ] + + math_instructions = [ + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + MathInstruction( + [16, 8, 64], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_accum), + ] + + min_cc = 89 + max_cc = 89 + + alignment_constraints = [16,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_types = [ + [ + math_inst.element_a, + math_inst.element_b, + DataType.f32, + math_inst.element_accumulator + ], + ] + + operations = [] + for data_type in data_types: + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, + alignment_constraints, None, EpilogueFunctor.LinearCombination) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 + +################################################################################################### + +# +def GenerateSM89(manifest, cuda_version): + GenerateSM89_TensorOp_16832_fp8_fp32acc(manifest, cuda_version) + GenerateSM89_TensorOp_16832_fp8_fp16acc(manifest, cuda_version) + GenerateSM89_SparseTensorOp_16864_fp8(manifest, cuda_version) + +################################################################################################### + + +try: + from .sm90_utils import ( + generate_fp16_bf16_math_instructions_sm90, + generate_tf32_math_instructions_sm90, + generate_int8_math_instructions_sm90, + generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, + make_sparse_math_instructions, + generate_tile_descriptions_sm90, + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + generate_fp16_bf16_math_instructions_sm90, + generate_tf32_math_instructions_sm90, + generate_int8_math_instructions_sm90, + generate_fp8_math_instructions_sm90, + generate_mixed_dtype_math_instructions_sm90, + make_sparse_math_instructions, + generate_tile_descriptions_sm90, + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_types = [data_type_w_source, data_type_wo_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_type_mixed_wo_source = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + data_types.append(data_type_mixed_wo_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_fp16_bf16_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_types = [data_type_w_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=131, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = make_sparse_math_instructions(generate_fp16_bf16_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_types = [data_type_w_source, data_type_wo_source] + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed_w_source = generate_data_types_from_math_instruction( + math_inst, + element_source=math_inst.element_a, + element_dest=math_inst.element_a + ) + data_type_mixed_wo_source = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=math_inst.element_a + ) + data_types.append(data_type_mixed_w_source) + data_types.append(data_type_mixed_wo_source) + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + ] + + math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + data_type_f32_wo_source["a_type"] = DataType.f32 + data_type_f32_wo_source["b_type"] = DataType.f32 + data_type_f32_wo_source["epi_type"] = DataType.f32 + data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] + + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_tf32_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + # Inconsistency: TF32 does not stamp out void-C + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + for data_type in [data_type_tf32, data_type_f32]: + # Inconsistency: alignments aren't fixed in TF32 / alignx + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=120, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + ] + + math_instructions = make_sparse_math_instructions(generate_tf32_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + + for layout in layouts: + data_type_tf32 = generate_data_types_from_math_instruction(math_inst) + data_type_tf32_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_f32 = copy.deepcopy(data_type_tf32) + data_type_f32_wo_source = copy.deepcopy(data_type_tf32_wo_source) + data_type_f32["a_type"] = DataType.f32 + data_type_f32["b_type"] = DataType.f32 + data_type_f32["epi_type"] = DataType.f32 + data_type_f32_wo_source["a_type"] = DataType.f32 + data_type_f32_wo_source["b_type"] = DataType.f32 + data_type_f32_wo_source["epi_type"] = DataType.f32 + data_types = [data_type_tf32, data_type_f32, data_type_tf32_wo_source, data_type_f32_wo_source] + + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + math_instructions = generate_int8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = generate_int8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=100, default_level=111, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + math_instructions = make_sparse_math_instructions(generate_int8_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + # s8.u8 and u8.s8 wgmma variants require PTX 8.4 + if math_inst.element_a != math_inst.element_b and not CudaToolkitVersionSatisfies(cuda_version, 12, 4): + continue + data_type_w_source = generate_data_types_from_math_instruction(math_inst) + data_type_wo_source = generate_data_types_from_math_instruction(math_inst, element_source=DataType.void) + data_type_int8_output = generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.s8, + element_dest=math_inst.element_a, + element_epilogue=DataType.f32 + ) + data_types = [data_type_w_source, data_type_wo_source, data_type_int8_output] + + for layout in layouts: + for data_type in data_types: + layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 3 if is_grouped(gemm_kind) else 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions_ = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + tile_descriptions = list() + + for desc in tile_descriptions_: + desc.explicit_vector_sizes = [1, desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [desc.tile_shape[0], desc.tile_shape[1], desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + desc.explicit_vector_sizes = [1, 1, desc.tile_shape[2]] + tile_descriptions.append(copy.deepcopy(desc)) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + gemm_kind=gemm_kind, + enable_fp8_fast_acc=False, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules, gemm_kind=gemm_kind) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK], + gemm_kind=gemm_kind) + + + +def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=0, default_level=101, exhaustive_level=9992) + is_aligned = False + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], # TN Layout + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = generate_fp8_math_instructions_sm90(instantiation_level) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [generate_data_types_from_math_instruction(math_inst)] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + +def GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 1): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9999) + is_aligned = True + + # layouts for ABC, their alignments will be fixed later based on the data type + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16]], + ] + + valid_types_for_a_b_acc = [ + (DataType.e4m3, DataType.f16, DataType.f32), + (DataType.e4m3, DataType.bf16, DataType.f32), + (DataType.e5m2, DataType.f16, DataType.f32), + (DataType.e5m2, DataType.bf16, DataType.f32), + (DataType.s8, DataType.f16, DataType.f32), + (DataType.s8, DataType.bf16, DataType.f32), + (DataType.u8, DataType.f16, DataType.f32), + (DataType.u8, DataType.bf16, DataType.f32), + (DataType.s4, DataType.f16, DataType.f32), + (DataType.s4, DataType.bf16, DataType.f32), + (DataType.s4, DataType.e4m3, DataType.f32), + (DataType.s4, DataType.e5m2, DataType.f32), + (DataType.u4, DataType.f16, DataType.f32), + (DataType.u4, DataType.bf16, DataType.f32), + (DataType.u2, DataType.f16, DataType.f32), + (DataType.u2, DataType.bf16, DataType.f32), + (DataType.s2, DataType.f16, DataType.f32), + (DataType.s2, DataType.bf16, DataType.f32), + ] + # Note: For sizeof(a_type) > sizeof(b_type), some generated kernels might crash due to a compiler bug. Disable it for now. + #swapped_valid_types_for_a_b_acc = [(b_type, a_type, acc_type) for a_type, b_type, acc_type in valid_types_for_a_b_acc] + #valid_types_for_a_b_acc = valid_types_for_a_b_acc + swapped_valid_types_for_a_b_acc + + math_instructions = generate_mixed_dtype_math_instructions_sm90(instantiation_level, valid_types_for_a_b_acc) + + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + + # Limit C/D types to avoid a giant number of instantiations. + # A typical use case for mixed dtype in DL is weight quantization (tensor A), + # therefore we can limit the output type to that of activation (tensor B). + valid_types_for_c = [math_inst.element_b] + valid_types_for_d = [math_inst.element_b] + + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Fix alignments, DataTypeSize are in the unit of bits + alignment_bits = 128 + layout[0][1] = alignment_bits // DataTypeSize[data_type['a_type']] + layout[1][1] = alignment_bits // DataTypeSize[data_type['b_type']] + layout[2][1] = alignment_bits // DataTypeSize[data_type['c_type']] + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 2): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=20, default_level=121, exhaustive_level=9992) + is_aligned = True + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = make_sparse_math_instructions(generate_fp8_math_instructions_sm90(instantiation_level)) + tile_descriptions = generate_tile_descriptions_sm90( + math_instructions=math_instructions, + is_aligned=is_aligned, + level=instantiation_level) + + for tile_desc in tile_descriptions: + math_inst = tile_desc.math_instruction + data_types = [] + fp8_types = [DataType.e4m3, DataType.e5m2] + valid_types_for_d = [DataType.f32, DataType.bf16, DataType.f16, DataType.e4m3, DataType.e5m2] + valid_types_for_c = copy.deepcopy(valid_types_for_d) + valid_types_for_c.append(DataType.void) + for c_type, d_type in product(valid_types_for_c, valid_types_for_d): + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=c_type, + element_dest=d_type, + ) + ) + else: + for d_type in valid_types_for_d: + data_types.append( + generate_data_types_from_math_instruction( + math_inst, + element_source=DataType.void, + element_dest=d_type, + ) + ) + + for layout in layouts: + for data_type in data_types: + # Inconsistency: alignments aren't fixed in FP8 + # layout = fix_alignments(data_type, layout, alignment_bits=128) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_desc, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_type, + instantiation_level=instantiation_level, + layout=layout, + ) + + if len(schedules): + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, schedules) + if len(stream_k_schedules): + assert CudaToolkitVersionSatisfies(cuda_version, 12, 1) + CreateSparseGemmUniversal3xOperator(manifest, [layout], [tile_desc], data_type, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]) + + +def GenerateSM90_TensorOp_1684(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = MathInstruction( + [16, 8, 4], + DataType.f64, DataType.f64, DataType.f64, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 16], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([32, 256, 16], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateGemmOperator(manifest, layouts, tile_descriptions, + data_type, alignment_constraints) + +# + +# +def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8 ], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8 ], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8 ], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8 ], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8 ], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8 ], 4, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 3, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64] + + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) + +# + +# +def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor), + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYRK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HERK computation + CreateRankKOperator(manifest, layouts, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints) +# + +# +def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + + +# +def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + diag_types = [ + DiagType.NonUnit, DiagType.Unit, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ + ComplexTransform.none, ComplexTransform.conj, + ] + + CreateTrmmOperator(manifest, layouts, side_modes, fill_modes, diag_types, tile_descriptions, \ + data_type, alignment_constraints, complex_transforms) +# + +# +def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([16, 32, 16], 5, [1, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 16, 16], 5, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f64, DataType.f64, DataType.f64, DataType.f64] + + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([128, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 8], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + +# +def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 8): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + side_modes = [ + SideMode.Left, SideMode.Right, + ] + + fill_modes = [ + FillMode.Lower, FillMode.Upper, + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 4], \ + DataType.f64, DataType.f64, DataType.f64, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_complex_gaussian) + + min_cc = 90 + max_cc = 90 + + alignment_constraints = [1,] + + tile_descriptions = [ + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 32, 8], 4, [2, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([16, 32, 8], 4, [1, 2, 1], math_inst, min_cc, max_cc), + #TileDescription([32, 16, 8], 4, [2, 1, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.cf64, DataType.cf64, DataType.cf64, DataType.cf64] + + complex_transforms = [ComplexTransform.none,] + + # SYMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.symmetric) + + # HEMM computation + CreateSymmOperator(manifest, layouts, side_modes, fill_modes, tile_descriptions, \ + data_type, alignment_constraints, BlasMode.hermitian) +# + + + +# Blackwell SM 100 generators + +try: + import cutlass_library.sm100_utils + from cutlass_library.sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) +except ImportError: + import sm100_utils + from sm100_utils import ( + generate_tf32_math_instructions_sm100, + generate_16b_math_instructions_sm100, + generate_f8f6f4_math_instructions_sm100, + generate_mxf8f6f4_math_instructions_sm100, + generate_mxf4nvf4_math_instructions_sm100, + generate_fp8_math_instructions_sm100, + generate_cluster_shapes_sm100, + get_pruning_level_from_global_level + ) + +################################################################################################### + +def get_tma_alignment_elt(data_type : DataType, is_f8f6f4 : bool = True ) -> int: + if DataTypeSize[data_type] < 8 and is_f8f6f4: + return int(128) + return int(16 * 8 / DataTypeSize[data_type]) + +sm100_cluster_shape_1sm = [ + [4,4,1] + , DynamicClusterShape +] + +sm100_cluster_shape_2sm = [ + # cluster_m % 2 == 0 for 2sm + [4,4,1] + , DynamicClusterShape +] + +def GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4]], + ] + + data_types = [ + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + math_instructions_1sm, math_instructions_2sm = generate_tf32_math_instructions_sm100(instantiation_level) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=490, default_level=490, exhaustive_level=9999) + + # layouts for ABC and their alignments. C alignment will be set later based on output type + layouts = [ + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 8], [LayoutType.RowMajor, 8], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + math_instructions_1sm, math_instructions_2sm = generate_16b_math_instructions_sm100(instantiation_level) + + min_cc = 100 + max_cc = thor_sm + grouped = is_grouped(gemm_kind) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + kernel_schedule = KernelScheduleType.TmaWarpSpecialized1SmSm100 if not grouped else KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100 + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized1Sm if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=591 , default_level=591 , exhaustive_level=9999) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + grouped = is_grouped(gemm_kind) + + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + + # 2xSM MMA kernels + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + if grouped: + epi_schedule = EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + elif math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + kernel_schedule = to_grouped_schedule(KernelScheduleType.TmaWarpSpecialized2SmSm100, grouped) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule]], tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=593, default_level=593, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + min_cc = 100 + max_cc = 100 + epi_type = DataType.f32 + + pruning_level = get_pruning_level_from_global_level(instantiation_level) + + math_instructions_1sm, math_instructions_2sm = generate_fp8_math_instructions_sm100(instantiation_level, enable_compile_time_dtype=grouped or pruning_level >= 1, enable_runtime_dtype=not grouped) + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level) + + tile_schedulers = [ + TileSchedulerType.Default, + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [math_inst.instruction_shape[0], math_inst.instruction_shape[1], + math_inst.instruction_shape[2] * 4])) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [1, math_inst.instruction_shape[1], + math_inst.instruction_shape[2] * 4])) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape, + [math_inst.instruction_shape[0], 1, + math_inst.instruction_shape[2] * 4])) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + for data_type in data_types: + if ( data_type["a_type"] == DataType.e4m3 ) and ( data_type["b_type"] == DataType.e4m3 ) and\ + ( data_type["d_type"] == DataType.e5m2 ): + continue + + is_runtime_datatype_a = is_runtime_datatype(data_type["a_type"]) + is_runtime_datatype_b = is_runtime_datatype(data_type["d_type"]) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + kernel_schedule = to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100, grouped) + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_schedule_nosmem = to_grouped_schedule(EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm, grouped) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, epi_schedule], [kernel_schedule, epi_schedule_nosmem]], + tile_schedulers=tile_schedulers, gemm_kind=gemm_kind) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.Universal3x): + + # SM100 MMA with mixed F4/F6/F8 inputs + without block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_f8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter out some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], tile_schedulers=tile_schedulers) + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + } + ] + + for kernel_data_type in kernel_data_types: + # Filter some kernel + if ( kernel_data_type["a_type"] == DataType.e4m3 ) and ( kernel_data_type["b_type"] == DataType.e4m3 ) and\ + ( kernel_data_type["d_type"] == DataType.e5m2 ): + continue + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + if math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], tile_schedulers=tile_schedulers) + else: + CreateGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]], tile_schedulers=tile_schedulers) + +def GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + + # SM100 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=590, default_level=590, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 128], [LayoutType.RowMajor, 0]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_mxf8f6f4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func) + + ab_types = [ + DataType.f4, DataType.f6, + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + ] + + # Set alignment d based on Destination format. + for data_type in data_types: + for layout in layouts: + # alignment for a + layout[0][1] = get_tma_alignment_elt(data_type["a_type"]) + # alignment for b + layout[1][1] = get_tma_alignment_elt(data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(data_type["d_type"]) + for tile in tile_descriptions: + math_inst = tile.math_instruction + # Filter some kernels that does not meet the alignment requirements. + if layout[0][0] == LayoutType.ColumnMajor: + if math_inst.instruction_shape[0] // 2 % layout[0][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[0][1] != 0: + continue + + if layout[1][0] == LayoutType.RowMajor: + if math_inst.instruction_shape[1] // 2 % layout[1][1] != 0: + continue + else: + if tile.threadblock_shape[2] // tile.cluster_shape[2] % layout[1][1] != 0: + continue + + if grouped: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[to_grouped_schedule(KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, grouped), to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + elif math_inst.instruction_shape[0] == 128: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + else: + CreateGemmUniversal3xOperator(manifest, [layout], [tile], [data_type], + [[KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100, EpilogueScheduleType.ScheduleAuto]] + , tile_schedulers = tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + + +def GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + instantiation_level = manifest.get_instantiation_level(pruned_level=591, default_level=591, exhaustive_level=9999) + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + ] + + math_instructions_1sm, math_instructions_2sm = generate_mxf4nvf4_math_instructions_sm100(instantiation_level, enable_runtime_dtype=not grouped) + + def change_priority_func(shapes_1sm, shapes_2sm): + shapes_1sm[(1,2,1)] = 6 + shapes_1sm[(1,4,1)] = 6 + shapes_2sm[(2,2,1)] = 6 + shapes_2sm[(2,4,1)] = 6 + shapes_2sm[(4,2,1)] = 6 + + cluster_shapes_1sm, cluster_shapes_2sm = generate_cluster_shapes_sm100(instantiation_level, change_priority_func=change_priority_func) + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if sfdtype["type"] == DataType.void or grouped: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + if thor_sm in manifest.compute_capabilities_baseline : + if [4,4,1] in cluster_shapes_1sm : + cluster_shapes_1sm.remove([4,4,1]) + if [4,4,1] in cluster_shapes_2sm : + cluster_shapes_2sm.remove([4,4,1]) + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + assert math_inst.instruction_shape[2] * 4 == 256 + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + # E2M1 x E2M1, vector size 16, UE4M3 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + epi_schedule = to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped) + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100, grouped) + + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind + ) + + for math_inst in math_instructions_2sm: + assert math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for layout in layouts: + for data_type in data_types: + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epi_schedule = EpilogueScheduleType.ScheduleAuto if not grouped else EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm + epi_nosmem_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + nvfp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped) + fp4_kernel_schedule = to_grouped_schedule(KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100, grouped) + + nvfp4_schedules = [[nvfp4_kernel_schedule, epi_schedule], [nvfp4_kernel_schedule, epi_nosmem_schedule]] + fp4_schedules = [[fp4_kernel_schedule, epi_schedule], [fp4_kernel_schedule, epi_nosmem_schedule]] + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, nvfp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, fp4_schedules + , tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + +def GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.BlockScaledUniversal3x): + # SM100 MMA with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 13, 0): + return + + grouped = is_grouped(gemm_kind) + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.ColumnMajor, 0]], + ] + + instruction_sizes_1sm = [ + [128, 128, 96], + ] + + instruction_sizes_2sm = [ + [256, 128, 96], + [256, 192, 96], + [256, 256, 96] + ] + + ab_types = [ + DataType.f4, + DataType.e2m1, + ] + + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + + acc_types = [ DataType.f32 ] # Accumulator is always 32 bits for block scaled MMA instructions + + def tile_schedulers(sfdtype): + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K. + if grouped: + return [TileSchedulerType.Default] + if sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 103 + max_cc = 103 + epi_type = DataType.f32 + + math_instructions_1sm = [] + + is_runtime_datatype = lambda runtime_datatype: runtime_datatype in (DataType.f4, DataType.f6, DataType.f8) + + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_1sm, ab_types, ab_types, sf_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_1sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + math_instructions_2sm = [] + + for instr_size, a_type, b_type, sf_type, acc_type in product(instruction_sizes_2sm, ab_types, ab_types, sf_types, acc_types): + is_runtime_datatype_a = is_runtime_datatype(a_type) + is_runtime_datatype_b = is_runtime_datatype(b_type) + + # A/B datatypes should be both static or dynamic + if (is_runtime_datatype_a != is_runtime_datatype_b): + continue + + math_instructions_2sm.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + cluster_shapes_1sm = [ + [1,1,1], + # [1,2,1], + [2,1,1], + # [1,4,1], + [4,4,1], + DynamicClusterShape + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + 768], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_1sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized1Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped), epilogue_1sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped), epilogue_1sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped), epilogue_1sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped), epilogue_1sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + cluster_shapes_2sm = [ + [2,1,1], + # [2,2,1], + # [2,4,1], + [4,1,1], + # [4,2,1], + [4,4,1], + DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 8 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + for data_type in data_types: + # Set alignment d based on Destination format. + if DataTypeSize[data_type["c_type"]] == 0 : + layout[2][1] = 256 // DataTypeSize[data_type["d_type"]] + else: + layout[2][1] = min(256 // DataTypeSize[data_type["d_type"]], 256 // DataTypeSize[data_type["c_type"]]) + + if data_type["sfd_type"]["type"] != DataType.void and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.RowMajor): + data_type["sfd_type"]["layout"] = layout[2][0] # For FP4 output , the scalefactor layout is same layout as D layout. + if (data_type["sfd_type"]["type"] != DataType.void) and (data_type["d_type"] == DataType.e2m1) and (layout[2][0] == LayoutType.ColumnMajor): + continue + # E2M1 x E2M1, vector size 32, E8 + isFp4 = math_inst.element_scale_factor == DataType.ue8m0 and math_inst.element_a == DataType.e2m1 and math_inst.element_b == DataType.e2m1 + + epilogue_2sm_schedule = to_grouped_schedule(EpilogueScheduleType.NoSmemWarpSpecialized2Sm, grouped) + + nvfp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped), epilogue_2sm_schedule] + nvfp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped), epilogue_2sm_schedule] + fp4_schedule_disable_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped), epilogue_2sm_schedule] + fp4_schedule_tma_prefetch = [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped), epilogue_2sm_schedule] + nvfp4_schedules = [nvfp4_schedule, nvfp4_schedule_disable_prefetch, nvfp4_schedule_tma_prefetch] + fp4_schedules = [fp4_schedule, fp4_schedule_disable_prefetch, fp4_schedule_tma_prefetch] + + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + nvfp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + if isFp4: + CreateGemmUniversal3xOperator(manifest, [layout], tile_descriptions, data_type, + fp4_schedules, tile_schedulers=tile_schedulers(data_type["sfd_type"]), gemm_kind=gemm_kind) + + +def GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.RowMajor, 0]], + [[LayoutType.RowMajor, 16], [LayoutType.RowMajor, 16], [LayoutType.RowMajor, 0]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + epi_type = DataType.f32 + + math_instructions_1sm = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add)] + + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1], [4,4,1] + , DynamicClusterShape + ] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,2,1], [2,1,1], [1,1,1], [1,4,1] + , DynamicClusterShape + ] + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types_mixed[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + math_instructions_2sm = [ + MathInstruction( + [128, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1] + , DynamicClusterShape + ] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1] + , DynamicClusterShape + ] + + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 4 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + if math_inst.instruction_shape[0] == 128: + epi_schedule = EpilogueScheduleType.TmaWarpSpecialized2Sm + else: + epi_schedule = EpilogueScheduleType.ScheduleAuto + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_types_mixed = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + }, + ] + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_types_mixed, + [[KernelScheduleType.TmaWarpSpecialized2SmSm100, epi_schedule]], tile_schedulers=tile_schedulers) + + +def GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 16], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.f16, + "b_type" : DataType.f16, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.f16, + "b_type" : DataType.f16, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 32], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # void_c + { + "a_type" : DataType.s8, + "b_type" : DataType.s8, + "c_type" : DataType.void, + "d_type" : DataType.s8, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : DataType.s8, + "b_type" : DataType.s8, + "c_type" : DataType.s8, + "d_type" : DataType.s8, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + MathInstruction( + [128, 128, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add)] + + math_instructions_2sm = [ + MathInstruction( + [256, 128, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + kernel_data_types = [ + # NOTE: a/b type in kernel will be overwrite below. + #* void_c + # f8_f8_f32_void_f16 + { + "a_type" : DataType.e4m3, + "b_type" : DataType.e4m3, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + #* non-void_c + # f8_f8_f32_f16_f8 + { + "a_type" : DataType.e4m3, + "b_type" : DataType.e4m3, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : DataType.f32, + "epi_type" : DataType.f32, + }, + ] + + math_instructions_1sm = [ + # Runtime DType + MathInstruction( + [128, 128, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update input AB type + kernel_data_type["a_type"] = math_inst.element_a + kernel_data_type["b_type"] = math_inst.element_b + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + for kernel_data_type in kernel_data_types: + # Update input AB type + kernel_data_type["a_type"] = math_inst.element_a + kernel_data_type["b_type"] = math_inst.element_b + + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_copy = copy.deepcopy(layouts) + for layout in layouts_copy: + # alignment for a, 2 for sparsity + layout[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_copy, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +def GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. + layouts = [ + # Alignment requirement will be over-write below + [[LayoutType.RowMajor, -1], [LayoutType.ColumnMajor, -1], [LayoutType.RowMajor, -1]], + ] + + thor_sm = ThorSMRenumbering(cuda_version) + + min_cc = 100 + max_cc = thor_sm + + tile_schedulers = [ + TileSchedulerType.Default, TileSchedulerType.StreamK + ] + + math_instructions_1sm = [ + # Runtime Dtype + MathInstruction( + [128, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [128, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [128, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + math_instructions_2sm = [ + # Runtime DType + MathInstruction( + [256, 128, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f4, DataType.f4, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + + MathInstruction( + [256, 128, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + MathInstruction( + [256, 256, 64], + DataType.f6, DataType.f6, DataType.f32, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add), + ] + + # 1xSM MMA kernels + for math_inst in math_instructions_1sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_1sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_1sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else cluster_shape + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_1sm[0], + math_inst.instruction_shape[1] * multiplier_1sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_1sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 128 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 128 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]], + tile_schedulers=tile_schedulers) + + # 2xSM MMA kernels + for math_inst in math_instructions_2sm: + tile_descriptions = [] + for cluster_shape in sm100_cluster_shape_2sm: + if thor_sm in manifest.compute_capabilities_baseline : + if cluster_shape == [4,4,1] : + continue + multiplier_2sm = (1, 1, 1) if cluster_shape == DynamicClusterShape else (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + tile_descriptions.append( + TileDescription([ + math_inst.instruction_shape[0] * multiplier_2sm[0], + math_inst.instruction_shape[1] * multiplier_2sm[1], + math_inst.instruction_shape[2] * 2 * multiplier_2sm[2]], + 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + kernel_data_types = [ + # void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + # none void_c + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32, + }, + ] + + for kernel_data_type in kernel_data_types: + # Update layout alignment + # alignment for d might be different for each kernel_data_type + layouts_filtered = [] + for layout in layouts: + layout_filter = copy.deepcopy(layout) + # * A_K : Logical TileShape_K % 256 == 0 + # * A_M : TileShape_M % 128 == 0 + # * B_N : TileSize_N % 256 == 0 + # * B_K : TileSize_K % 128 == 0 + if ((layout_filter[0][0] == LayoutType.RowMajor and (math_inst.instruction_shape[2] * 2) % 256 == 0) or \ + (layout_filter[0][0] == LayoutType.ColumnMajor and math_inst.instruction_shape[0] % 128 == 0)) and \ + ((layout_filter[1][0] == LayoutType.RowMajor and math_inst.instruction_shape[1] % 256 == 0) or \ + (layout_filter[1][0] == LayoutType.ColumnMajor and (math_inst.instruction_shape[0] * 2) % 128 == 0)): + # alignment for a, 2 for sparsity + layout_filter[0][1] = get_tma_alignment_elt(kernel_data_type["a_type"]) * ( 2 if layout[0][0] == LayoutType.RowMajor else 1) + # alignment for b + layout_filter[1][1] = get_tma_alignment_elt(kernel_data_type["b_type"]) + # alignment for d + layout_filter[2][1] = get_tma_alignment_elt(kernel_data_type["d_type"]) + layouts_filtered.append(layout_filter) + + CreateSparseGemmUniversal3xOperator(manifest, layouts_filtered, tile_descriptions, [kernel_data_type], + [[KernelScheduleType.SparseTmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]], + tile_schedulers=tile_schedulers) + +# Conv Utility functions +def make_dims_and_alignments_triple(dim: int, bit_per_element_A: int, bit_per_element_B: int, bit_per_element_C: int): + bit_alignment_required_by_tma = 128 + return ((dim, bit_alignment_required_by_tma // bit_per_element_A), # A + (dim, bit_alignment_required_by_tma // bit_per_element_B), # B + (dim, bit_alignment_required_by_tma // bit_per_element_C)) # C + +def make_math_instruction_w_output(data_types: Tuple[DataType, DataType, DataType, DataType], + instruction_shape: Tuple[int, int, int]) -> (MathInstruction, DataType): + default_opcode = OpcodeClass.TensorOp + default_math_op = MathOperation.multiply_add + [A_data_type, B_data_type, Acc_data_type, Out_data_type] = data_types + return (MathInstruction( + instruction_shape, + A_data_type, B_data_type, Acc_data_type, + default_opcode, + default_math_op + ), Out_data_type) + +""" +Generate CUTLASS 3 convolution kernel(s) for SM100. + +This is meant to be called from GenerateSM100. +""" +def GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version, + log_indent_level: int = 0): + log_debug_line('GenerateSM100_TensorOp_16b_UMMA_conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + thor_sm = ThorSMRenumbering(cuda_version) + + minimum_compute_capability = 100 + maximum_compute_capability = thor_sm + + spatial_dims = [2, 3] + + conv_kinds = [ + ConvKind.Fprop, + ConvKind.Dgrad, + ConvKind.Wgrad + ] + + stages = 0 # zero means "deduce the number of stages automatically" + + data_types_and_instruction_shapes_1sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (64, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (64, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (64, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), + ] + math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_1sm) + + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] + + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] + + # tile_descriptions is a 2-level list. + # Each inner list is for each cluster shape. + for math_inst, output_type in math_instructions_w_output_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + cluster_multiplier = cluster_shape + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + # It's typical to get the data types from the math instruction. + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + for conv_kind in conv_kinds: + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + + data_types_and_instruction_shapes_2sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f16, DataType.f16), (256, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 128, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (128, 256, 16)), + ((DataType.f16, DataType.f16, DataType.f32, DataType.f16), (256, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 128, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (128, 256, 16)), + ((DataType.bf16, DataType.bf16, DataType.f32, DataType.bf16), (256, 256, 16)), + ] + math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_2sm) + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] + + for math_inst, output_type in math_instructions_w_output_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + # It's typical to get the data types from the math instruction. + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + for conv_kind in conv_kinds: + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + +def GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version, + log_indent_level: int = 0): + # Instantiate Fp8 Fprop kernels with e4m3 A/B, f32 Acc, e4m3/bf16/f16/f32 C/D + log_debug_line('GenerateSM100_TensorOp_fp8_UMMA_conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + thor_sm = ThorSMRenumbering(cuda_version) + + minimum_compute_capability = 100 + maximum_compute_capability = thor_sm + + spatial_dims = [2, 3] + stages = 0 # zero means "deduce the number of stages automatically" + + data_types_and_instruction_shapes_1sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (64, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), + ] + math_instructions_w_output_1sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_1sm) + + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1],[4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_1sm = [[1,1,1], [1,2,1], [1,4,1]] + + for math_inst, output_type in math_instructions_w_output_1sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_1sm: + cluster_multiplier = cluster_shape + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = ConvKind.Fprop, + log_indent_level = log_indent_level) + + data_types_and_instruction_shapes_2sm = [ + # ((A,B,Acc,C/D), (InstM,InstN,InstK)) + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.e4m3), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f16), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.bf16), (256, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 128, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (128, 256, 32)), + ((DataType.e4m3, DataType.e4m3, DataType.f32, DataType.f32), (256, 256, 32)), + ] + math_instructions_w_output_2sm = map(lambda x: make_math_instruction_w_output(*x), + data_types_and_instruction_shapes_2sm) + + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1], [4,4,1]] + if thor_sm in manifest.compute_capabilities_baseline : + cluster_shapes_2sm = [[2,1,1], [2,2,1], [2,4,1], [4,1,1], [4,2,1]] + + for math_inst, output_type in math_instructions_w_output_2sm: + tile_descriptions = [] + for cluster_shape in cluster_shapes_2sm: + cluster_multiplier = (cluster_shape[0] // 2, cluster_shape[1], cluster_shape[2]) + # Unlike SM90, SM100 tile shape calculation includes cluster shape. + tile_shape = [ + math_inst.instruction_shape[0] * cluster_multiplier[0], + math_inst.instruction_shape[1] * cluster_multiplier[1], + math_inst.instruction_shape[2] * 4 * cluster_multiplier[2] + ] + warp_count = [4, 1, 1] + tile_description = TileDescription( + tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, + cluster_shape) + tile_descriptions.append(tile_description) + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : output_type, + "d_type" : output_type, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + dims_and_alignments = [make_dims_and_alignments_triple(dim, DataTypeSize[data_type["a_type"]], DataTypeSize[data_type["b_type"]], DataTypeSize[data_type["d_type"]]) for dim in spatial_dims] + + # Schedules + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100 + epilogue_schedule = EpilogueScheduleType.ScheduleAuto + schedule_pairs = [ + (mainloop_schedule, epilogue_schedule) + ] + + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = tile_descriptions, + data_types = data_type, + schedule_pairs = schedule_pairs, + conv_kind = ConvKind.Fprop, + log_indent_level = log_indent_level) + +def GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with mixed F4/F6/F8 inputs + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 32] + ] + + tile_sizes = [ + [128, 128, 128] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + DataType.e2m1, + DataType.e2m3, + DataType.e3m2, + DataType.e5m2, + DataType.e4m3, + ] + + acc_types = [ DataType.f32 ] + + def is_pingpong(kernel_schedule): + if kernel_schedule == KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def tile_schedulers(sfdtype, kernel_schedule): + # Pingpong kernel schedule doesn't support stream-K. + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K + if is_pingpong(kernel_schedule): + return [TileSchedulerType.Default] + elif sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120 + ] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes, ab_types, ab_types, acc_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e3m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type, kernel_schedule in product(data_types, kernel_schedules): + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), + gemm_kind = GemmKind.BlockScaledUniversal3x + ) + +def GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version): + # SM120 MMA with with F4 + block scale + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + # layouts for ABC and their alignments. + layouts = [ + [[LayoutType.RowMajor, 32], [LayoutType.ColumnMajor, 32], [LayoutType.RowMajor, 0]] + ] + + instruction_sizes = [ + [16, 8, 64] + ] + + tile_sizes_cooperative = [ + [128, 128, 128], + [128, 128, 256], + [256, 128, 128] + ] + + tile_sizes_pingpong = [ + [128, 128, 128], + [128, 128, 256] + ] + + cluster_shape = [1,1,1] + + ab_types = [ + DataType.e2m1 + ] + + sf_types = [ + DataType.ue4m3, + DataType.ue8m0 + ] + + acc_types = [ DataType.f32 ] + + def is_pingpong(kernel_schedule): + if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120 or \ + kernel_schedule == KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def is_nvf4(kernel_schedule): + if kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120 or \ + kernel_schedule == KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: + return True + else: + return False + + def tile_schedulers(sfdtype, kernel_schedule): + # Pingpong kernel schedule doesn't support stream-K. + # Only use the stream-K scheduler for non-void SFD to limit kernel count. When SFD is void, + # the epilogue is the traditional linear combination, for which we already have tests with stream-K + if is_pingpong(kernel_schedule): + return [TileSchedulerType.Default] + elif sfdtype["type"] == DataType.void: + return [TileSchedulerType.Default] + else: + return [TileSchedulerType.Default, TileSchedulerType.StreamK] + + min_cc = 120 + max_cc = 121 + + epi_type = DataType.f32 + + math_instructions = [] + + kernel_schedules = [ + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120, + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120, + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120 + ] + + for instr_size, a_type, b_type, acc_type, sf_type in product(instruction_sizes, ab_types, ab_types, acc_types, sf_types): + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + sf_type) + ) + + for math_inst in math_instructions: + for kernel_schedule in kernel_schedules: + tile_descriptions = [] + tile_sizes = tile_sizes_pingpong if is_pingpong(kernel_schedule) else tile_sizes_cooperative + for tile_size in tile_sizes: + # nvf4 kernel only supports ue4m3 SF + # mxf4 kernel only supports ue8m0 SF + if (math_inst.element_scale_factor == DataType.ue4m3 and is_nvf4(kernel_schedule)) or \ + (math_inst.element_scale_factor == DataType.ue8m0 and not is_nvf4(kernel_schedule)): + tile_descriptions.append( + TileDescription(tile_size, 0, [4, 1, 1], math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.void, "vector_size": None, "layout" : None} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 16, "layout" : LayoutType.RowMajor} + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e2m1, + "acc_type" : math_inst.element_accumulator, + "epi_type" : epi_type, + "sf_type" : math_inst.element_scale_factor, + "sfd_type" : {"type": DataType.ue8m0, "vector_size": 32, "layout" : LayoutType.RowMajor} + } + ] + + # Set alignment d based on Destination format. + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_types[0]["d_type"]] + + for data_type in data_types: + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(data_type["sfd_type"], kernel_schedule), + gemm_kind = GemmKind.BlockScaledUniversal3x + ) + +def GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 256], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 0]] + ] + + tile_sizes = [ + [128, 128, 256] + ] + + cluster_shape = [1,1,1] + + warp_count = [4, 2, 1] + + acc_types = [ DataType.f32 ] + + instruction_sizes_mxf8f6f4 = [ + [16, 8, 64] + ] + + ab_types_mxf8f6f4 = [ + DataType.e2m1, + #DataType.e2m3, + DataType.e3m2, + #DataType.e5m2, + DataType.e4m3, + ] + + def tile_schedulers(kernel_schedule): + return [TileSchedulerType.Default] + + min_cc = 120 + max_cc = 121 + + kernel_schedules = [ + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120, + ] + + math_instructions_mxf8f6f4 = [] + + for instr_size, a_type, b_type, acc_type in product(instruction_sizes_mxf8f6f4, ab_types_mxf8f6f4, ab_types_mxf8f6f4, acc_types): + math_instructions_mxf8f6f4.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.SparseTensorOp, + MathOperation.multiply_add) + ) + + # Create gemm operator for mxf8f6f4 + for math_inst in math_instructions_mxf8f6f4: + tile_descriptions_mxf8f6f4 = [] + for tile_size in tile_sizes: + tile_descriptions_mxf8f6f4.append( + TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape)) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + for data_type, kernel_schedule in product(data_types, kernel_schedules): + # Set alignment d based on Destination format + for layout in layouts: + layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) + # Create gemm operator + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_mxf8f6f4, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(kernel_schedule), + gemm_kind = GemmKind.SparseUniversal3x) + +def GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.BlockwiseUniversal3x): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 8): + return + + layouts = [ + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.RowMajor, 16]], + [[LayoutType.RowMajor, 128], [LayoutType.ColumnMajor, 128], [LayoutType.ColumnMajor, 16]] + ] + + cooperative_tile_sizes = [ + [128, 128, 128] + ] + pingpong_tile_sizes = [ + [64, 128, 128] + ] + + def get_tile_sizes(kernel_scheduler): + if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: + return pingpong_tile_sizes + return cooperative_tile_sizes + + def get_warp_count(kernel_scheduler): + if kernel_scheduler == KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: + return [2, 2, 1] + return [4, 2, 1] + + def get_sf_sizes(tile_size): + sf_sizes = [] + for vec_m in [1, 128]: + if tile_size[0] % vec_m > 0: + continue + for vec_n in [1, 128]: + if tile_size[1] % vec_m > 0: + continue + sf_sizes.append( + [vec_m, vec_n, 128] + ) + return sf_sizes + + cluster_shape = [1,1,1] + + acc_types = [ DataType.f32 ] + + instruction_sizes = [ + [16, 8, 32] + ] + + def tile_schedulers(kernel_schedule): + return [TileSchedulerType.Default] + + min_cc = 120 + max_cc = 121 + + kernel_schedulers = [ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120 + ] + + ab_types = [ + [DataType.e4m3, DataType.e4m3], + [DataType.e4m3, DataType.e5m2] + ] + + math_instructions = [] + + for instr_size, ab_type, acc_type in product(instruction_sizes, ab_types, acc_types): + a_type, b_type = ab_type + math_instructions.append( + MathInstruction( + instr_size, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + # Create gemm operator for mxf8f6f4 + for kernel_schedule in kernel_schedulers: + tile_sizes = get_tile_sizes(kernel_schedule) + warp_count = get_warp_count(kernel_schedule) + for math_inst in math_instructions: + tile_descriptions = [] + for tile_size in tile_sizes: + sf_sizes = get_sf_sizes(tile_size) + for sf_size in sf_sizes: + tile_descriptions.append( + TileDescription(tile_size, 0, warp_count, math_inst, min_cc, max_cc, cluster_shape, + explicit_vector_sizes=sf_size) + ) + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.void, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + for data_type in data_types: + # Set alignment d based on Destination format + for layout in layouts: + layout[2][1] = int(128 // DataTypeSize[data_type["d_type"]]) + # Create gemm operator + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[kernel_schedule, EpilogueScheduleType.ScheduleAuto]], + tile_schedulers = tile_schedulers(kernel_schedule), + gemm_kind = gemm_kind) + +def GenerateSM100(manifest, cuda_version): + arch_family_cc = ['100f', '101f', '103a'] + if CudaToolkitVersionSatisfies(cuda_version, 13, 0): + for old_cc, new_cc in [('101f', '110f')]: + arch_family_cc = [cc.replace(old_cc, new_cc) for cc in arch_family_cc] + + # + # Dense Gemm + # + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version) + + GenerateSM100_TensorOp_32b_UMMA_gemm(manifest, cuda_version) + + if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): + GenerateSM100_TensorOp_int8_UMMA_gemm(manifest, cuda_version) + + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version) + # grouped GEMM + GenerateSM100_TensorOp_fp8_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM100_TensorOp_16b_UMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + + # StreamK is included in regular generation + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) + + # Blockwise kernels + GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM100_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) + + # + # Sparse Gemm + # + GenerateSM100_SparseTensorOp_32b_UMMA_gemm(manifest, cuda_version) + GenerateSM100_SparseTensorOp_16b_UMMA_gemm(manifest, cuda_version) + if not bool(set(manifest.compute_capabilities_feature_set).intersection(arch_family_cc)): + GenerateSM100_SparseTensorOp_int8_UMMA_gemm(manifest, cuda_version) + GenerateSM100_SparseTensorOp_fp8_UMMA_gemm(manifest, cuda_version) + GenerateSM100_SparseTensorOp_mixed_8bits_UMMA_gemm(manifest, cuda_version) + + # + # Block Scaled Gemm + # + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM100_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM103_TensorOp_fp4_ultra_UMMA_gemm_with_block_scaled(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockScaledUniversal3x) + # + # Conv + # + GenerateSM100_TensorOp_16b_UMMA_conv3x(manifest, cuda_version) + GenerateSM100_TensorOp_fp8_UMMA_conv3x(manifest, cuda_version) + + +def GenerateSM120(manifest, cuda_version): + # StreamK is included in regular generation # + # + # Dense Block Scaled Gemm + # + GenerateSM120_TensorOp_mixed_8bits_UMMA_gemm_with_block_scaled(manifest, cuda_version) + GenerateSM120_TensorOp_fp4_UMMA_gemm_with_block_scaled(manifest, cuda_version) + + # + # Sparse Gemm + # + GenerateSM120_Sparse_TensorOp_gemm(manifest, cuda_version) + GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM120_TensorOp_fp8_UMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) + +################################################################################################### + +def GenerateSM90_Conv3x(manifest, cuda_version, + log_indent_level: int = 0): + """ + Generate CUTLASS 3 convolution kernel(s) for SM90. + + This is meant to be called from GenerateSM90. + """ + log_debug_line('GenerateSM90_Conv3x', log_indent_level) + log_indent_level = log_indent_level + 1 + + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + minimum_compute_capability = 90 + maximum_compute_capability = 90 + + spatial_dims = (2, 3) + + # MMA shapes (MMA_M, MMA_N, MMA_K): + # + # Different hardware MMA instructions may have different MMA shapes. + # This function may generate kernels with different MMA shapes for + # different data types, either because the hardware only supports + # certain shapes for certain types, or for performance reasons + # (CUTLASS doesn't need to generate all valid kernels for the + # profiler library, just the best-performing ones). + # + # The kernel names refer to tile shapes (TILE_M, TILE_N, TILE_K) + # instead of MMA shapes. For SM >= 90 kernels, TILE_K = 4 * MMA_K, + # where 4, the "number of MMA instructions per tile," is determined + # through some combination of modeling and experiment. + # + # For performance on sm90, generally CUTLASS generates 64x128 + # instead of 128x64. + mma_64x64x16 = ( 64, 64, 16) + mma_64x64x8 = ( 64, 64, 8) + + num_mma_per_tile = 4 + + # Cluster shapes (1, 1, 1) and (2, 2, 1) are valid, + # but not included, because they tend not to perform as well. + cluster_shapes = ( + (2, 1, 1), + (1, 2, 1), + ) + + fp16 = DataType.f16 + bf16 = DataType.bf16 + fp32 = DataType.f32 + s8 = DataType.s8 + s32 = DataType.s32 + + # When generating kernels, the usual way is to specify 4 types, + # (A, B, Acc, C/D). Tests instead have 5 types, + # (ElementAct, ElementFlt, ElementOut, ElementAcc, ElementCompute), + # where ElementCompute is also called 'epi_type', + # and corresponds to the type of epilogue activations. + # This script maps tests' 5 types to 4 types + # by making ElementCompute the same as ElementOut. + + fp16_fp32_fp16_fp32 = { + 'a_type': fp16, # ElementAct(ivation) + 'b_type': fp16, # ElementF(i)lt(er) + 'c_type': fp32, # ElementAcc + 'd_type': fp32, # ElementOut (used only by CollectiveEpilogue) + 'acc_type': fp16, # ElementAcc + 'epi_type': fp32, # ElementCompute (used only by CollectiveEpilogue) + 'alignment_A': 8, # tma alignment elements of A + 'alignment_B': 8, # tma alignment elements of B + 'alignment_C': 4, # tma alignment elements of C + } + fp16_fp32_fp32_fp32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 4, + } + fp32_fp32_fp32_fp32 = { + 'a_type': fp32, + 'b_type': fp32, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 4, + 'alignment_B': 4, + 'alignment_C': 4, + } + s8_s32_s32_s32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + 'alignment_A': 16, + 'alignment_B': 16, + 'alignment_C': 4, + } + + # Other NVIDIA libraries may have the habit of specifying data types like this. + bf16bf16_bf16f32_f32 = { + 'a_type': bf16, + 'b_type': bf16, + 'c_type': fp32, + 'd_type': fp32, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 4, + } + f16f16_f16f16_f16 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp16, + 'epi_type': fp16, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 8, + } + f16f16_f16f32_f32 = { + 'a_type': fp16, + 'b_type': fp16, + 'c_type': fp16, + 'd_type': fp16, + 'acc_type': fp32, + 'epi_type': fp32, + 'alignment_A': 8, + 'alignment_B': 8, + 'alignment_C': 8, + } + f32f32_tf32f32_f32 = fp32_fp32_fp32_fp32 + + i8i8_i8i32_f32 = { + 'a_type': s8, + 'b_type': s8, + 'c_type': s32, + 'd_type': s32, + 'acc_type': s32, + 'epi_type': s32, + 'alignment_A': 16, + 'alignment_B': 16, + 'alignment_C': 4, + } + + # Each element in the outermost iterable is one combination of + # + # (ConvKind, spatial_dimension, data_types, byte_alignments, mma_sizes, cluster_sizes) + # + # for which to generate a kernel. spatial_dimension is the spatial + # dimension of the convolution: either 1, 2, or 3. byte_alignments + # is a triple of required minimum byte alignments for A, B, and C. + # + # Note that itertools functions produce a single-pass generator. + # The code doesn't need a multipass iterable, but if one did, one + # could call `tuple` or `list` on the generator. + # + # While this happens to use the same cluster sizes for each element, + # the code doesn't require that. Different convolution kinds, data + # types, or mma sizes might have different optimal cluster sizes. + combinations_of_parameters = chain( + # The following are all the kernels exercised in the unit tests. + # Please try to keep in sync with the unit tests. + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + s8_s32_s32_s32, + ), + ( + mma_64x64x16, + ), + cluster_shapes + ), + product( + ( + ConvKind.Fprop, + ), + spatial_dims, + ( + fp32_fp32_fp32_fp32, + ), + ( + mma_64x64x8, + ), + cluster_shapes + ), + product( + ( + ConvKind.Dgrad, + ConvKind.Wgrad + ), + spatial_dims, + ( + fp16_fp32_fp16_fp32, + fp16_fp32_fp32_fp32, + ), + ( + mma_64x64x16, + ), + cluster_shapes + ), + # Kernels not necessarily in the unit tests, but used elsewhere + # and thus useful to have generated for profiling. They may + # duplicate kernels above. All of them are 2-D. In general, + # CUTLASS prefers 64 x 128 to 128 x 64 on sm90, even if the + # hardware permits 128 x 64. + ( + # Fprop + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 192, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 96, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f32f32_tf32f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 192, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Fprop, 2, f32f32_tf32f32_f32, (256, 96, 8), (2, 1, 1)), + # + # i8i8_i8i32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (128, 256, 32), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 16), (2, 1, 1)), + (ConvKind.Fprop, 2, i8i8_i8i32_f32, (256, 128, 32), (2, 1, 1)), + # + # Dgrad + # + # bf16bf16_bf16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, bf16bf16_bf16f32_f32, (256, 128, 16), (2, 1, 1)), + # + # f16f16_f16f16_f16 + # + # cluster shape (1, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, ( 64, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 128, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (128, 256, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 64, 16), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 8), (1, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f16_f16, (256, 128, 16), (1, 1, 1)), + # + # f16f16_f16f32_f32 + # + # cluster shape (2, 1, 1) + # + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (128, 256, 16), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 8), (2, 1, 1)), + (ConvKind.Dgrad, 2, f16f16_f16f32_f32, (256, 128, 16), (2, 1, 1)), + ), + ) + + # SM >= 90 kernels don't actually use warp_count, but the + # TileDescription class needs it. The 4 in the default + # warp_count has nothing to do with num_mma_per_tile. + warp_count = [4, 1, 1] + + stages = 0 # zero means "deduce the number of stages automatically" + + mainloop_schedule = KernelScheduleType.ImplicitTmaWarpSpecializedSm90 + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + schedule_pairs = ( + (mainloop_schedule, epilogue_schedule), + ) + tile_schedulers = ( + TileSchedulerType.Default, # -> void + ) + + def make_math_instruction(data_types: Dict[str, DataType], + mma_shape: Tuple[int, int, int]) -> MathInstruction: + default_opcode = OpcodeClass.TensorOp + default_math_op = MathOperation.multiply_add + return MathInstruction( + mma_shape, + data_types['a_type'], data_types['b_type'], data_types['c_type'], + default_opcode, + default_math_op + ) + + for (conv_kind, spatial_dim, data_types, mma_shape, cluster_shape) in combinations_of_parameters: + math_inst = make_math_instruction(data_types, mma_shape) + tile_shape = (mma_shape[0], mma_shape[1], num_mma_per_tile * mma_shape[2]) + tile_description = TileDescription(tile_shape, stages, warp_count, math_inst, + minimum_compute_capability, maximum_compute_capability, cluster_shape) + assert(isinstance(spatial_dim, int)) + dims_and_alignments = ( + ( + (spatial_dim, data_types['alignment_A']), + (spatial_dim, data_types['alignment_B']), + (spatial_dim, data_types['alignment_C']), + ), + ) + CreateConvOperator3x(manifest, + dims_and_alignments = dims_and_alignments, + tile_descriptions = [tile_description], + data_types = data_types, + schedule_pairs = schedule_pairs, + tile_schedulers = tile_schedulers, + conv_kind = conv_kind, + log_indent_level = log_indent_level) + +def GenerateSM90(manifest, cuda_version): + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_mixed_dtype_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_1684(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version, gemm_kind=GemmKind.GroupedUniversal3x) + GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version) + GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version) + GenerateSM90_Conv3x(manifest, cuda_version) + GenerateSM90_SparseTensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_SparseTensorOp_tf32_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_SparseTensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_SparseTensorOp_fp8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm_with_blockwise(manifest, cuda_version, gemm_kind=GemmKind.GroupedBlockwiseUniversal3x) + +################################################################################################### + +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level + into the numeric identifier used in setting the log level. + + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f'Invalid log level: {log_level}') + return numeric_level + +# This function for defining the ArgumentParser is used to make it easy for the CUTLASS Python interface +# to leverage the functionality in this file without running this script via a shell prompt. +def define_parser(): + parser = argparse.ArgumentParser(description="Generates device kernel registration code for CUTLASS Kernels") + parser.add_argument("--operations", default="all", help="Specifies the operation to generate (gemm, all)") + parser.add_argument("--build-dir", default=".", required=False, help="CUTLASS top-level build directory") + parser.add_argument("--curr-build-dir", default=".", help="CUTLASS current build directory. cmake files will be emitted in this directory") + parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") + parser.add_argument("--architectures", default='53;60;61;70;75;80;90;100', help="Target compute architectures") + parser.add_argument("--kernels", default='', help='Comma-delimited list to filter kernels by name. ' + + 'Specifying this as \"all\" includes ALL the kernels, ' + + 'while not specifying this includes only the default set of kernels.') + parser.add_argument("--ignore-kernels", default='', help='Comma-delimited list of kernels ' + + 'to exclude from build. For backwards compatibility reasons, ' + + 'this option only takes effect if --kernels is set to a nonempty value.') + parser.add_argument("--exclude-kernels", default='', help='Comma-delimited list of kernels ' + + 'to exclude from build. In contrast to --ignore-kernels, ' + + 'this option always takes effect, ' + + 'whether or not --kernels is set to a nonempty value. ' + + 'It also can exclude kernels from the filter file ' + + '(see --kernel-filter-file option below).') + parser.add_argument("--filter-by-cc", default='True', type=str, help='If enabled, kernels whose compute capability range is not satisfied by the build target are excluded.') + parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") + parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') + parser.add_argument('--heuristics-problems-file', type=str, default=None, required=False, help='Full path of heuristics problem size description file, as a json list') + parser.add_argument('--heuristics-testlist-file', type=str, default=None, required=False, help='Full path of heuristics testlist CSV file, to be passed to cutlass_profiler') + parser.add_argument('--heuristics-gpu', type=str, default=None, required=False, help='GPU to use for evaluating heuristics offline. None or `auto` to autodetect using cuda', choices=['', 'auto', 'H100_SXM', 'H100_PCIE', 'H100_NVL', 'H200_SXM', 'H20_SXM', 'B200', 'GB200_NVL', 'RTX_5080', 'RTX_5090', 'RTX_PRO_6000']) + parser.add_argument('--heuristics-configs-per-problem', type=int, default=10, required=False, help='Number of kernel configs to generate for each problem in the problem list') + parser.add_argument('--heuristics-restrict-kernels', action='store_true', help='Restrict heuristics mode to use only the default set of kernels emitted by generator.py') + parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, + help='Specify the output log file containing all enabled kernels in this build') + parser.add_argument("--interface-dir", default=None, required=False, help="Interface header to kernels") + parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures") + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + parser.add_argument('--instantiation-level', type=str, default="", required=False, help="Instantiation level for SM90 kernels. Set to `max` and make sure `--kernels` is not empty to generate all possible configurations.") + _add_package_disablement_flag(parser) + return parser + + +if __name__ == "__main__": + parser = define_parser() + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + manifest = Manifest(args) + + archs = args.architectures.split(';') + + if args.heuristics_problems_file: + filter_manifest_and_write_heuristics_file(manifest, args) + + GenerateSM50(manifest, args.cuda_version) + GenerateSM60(manifest, args.cuda_version) + GenerateSM61(manifest, args.cuda_version) + GenerateSM70(manifest, args.cuda_version) + GenerateSM75(manifest, args.cuda_version) + GenerateSM80(manifest, args.cuda_version) + GenerateSM89(manifest, args.cuda_version) + GenerateSM90(manifest, args.cuda_version) + + blackwell_arch_list = [ + "100a", "100f", + "101a", "101f", + "103a", "103f", + "110a", "110f", + "120a", "120f", + "121a", "121f", + ] + blackwell_enabled_arch = any(arch in blackwell_arch_list for arch in archs) + if blackwell_enabled_arch: + GenerateSM100(manifest, args.cuda_version) + GenerateSM120(manifest, args.cuda_version) + + if 'library' in args.generator_target.split(','): + manifest.emit(GeneratorTarget.Library) + + if 'kernel_testlist_l0' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L0") + + if 'kernel_testlist_l1' in args.generator_target.split(','): + emit_gemm_kernel_testlist(manifest, args.curr_build_dir, args.architectures, "functional_L1") + + if args.selected_kernel_list is not None: + if len(manifest.selected_kernels) > 0: + with open(args.selected_kernel_list, 'w') as file_writer: + for line in manifest.selected_kernels: + file_writer.write("%s\n" % line) + +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py new file mode 100644 index 0000000000000000000000000000000000000000..83421a06427acdc3b059855991cf95a1d2f118b3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py @@ -0,0 +1,415 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for selecting CUTLASS library kernels based on problem description +""" +import json +import csv + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.generator import * + from cutlass_library.heuristics_provider import * +except ImportError: + from library import * + from generator import * + from heuristics_provider import * + +try: + from .sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) +except ImportError: + from sm90_utils import ( + get_valid_schedules, + generate_data_types_from_math_instruction, + fix_alignments, + ) + +_LOGGER = logging.getLogger(__name__) + +dtype_map = {v: k for k, v in DataTypeNames.items()} + +def serialize_heuristics_results_to_json(problems_with_configs, outfile_path): + """ + Utilitiy function to write heuristics results to a json file for debug + + args: + problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict + outfile_path: Outfile path + + returns: + None + """ + pc_copy = problems_with_configs.copy() + for p in pc_copy: + for k, v in p.items(): + if isinstance(v, DataType): + p[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + p[k] = ShortLayoutTypeNames[v] + configs = p['configs'] + for c in configs: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + with open(outfile_path, 'w') as f: + json.dump(pc_copy, f, indent=2) + +def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None): + """ + Get heuristic-suggested GEMM kernel configurations for a single GEMM problem. + + args: + m, n, k: GEMM dimensions + batch_count: batch count + layouts: tuple of layouts of type LayoutType + use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions + count: Number of configs to return + provider: Heuristics provider to use + + returns: + A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys: + - 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size + - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size + - 'stages': kernel pipeline stage count + - 'cluster_m', 'cluster_n', 'cluster_k': cluster size + - 'layout_a', 'layout_b': input tensor layouts of type LayoutType + - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements + - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType + - 'swizzle_size' : suggested threadblock swizzle + - 'split_k_slices': number of partitions of the k dimension for splitK + - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n') + """ + if provider is None: + provider = MatmulHeuristics() + return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count) + +def get_gemm_configs(problems, provider=None, count=1): + """ + Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems. + + args: + problems: List of dictionaries describing GEMM problems with the following keys: + - 'm', 'n', 'k': Matrix dimensions (required) + - 'dtype_a': Data type of matrix A (required) + - 'dtype_b': Data type of matrix B (required) + - 'dtype_c': Data type of matrix C (default: None) + - 'dtype_d': Data type of matrix D (required) + - 'dtype_acc': Compute data type (default 'f32') + - 'layout': Operation layout (e.g. 'tnt') + - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements) + - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements) + - 'alpha': Scalar multiplier for A*B (default: 1.0) + - 'beta': Scalar multiplier for C (default: 0.0) + - 'batch_count': Number of GEMM operations in batch (default: 1) + - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True) + provider: Heuristics provider to use + count: Number of configurations to return per problem (defualt: 1) + + returns: + A copy of the input dictionary, with key `configs` added containing the selected gemm configs + """ + ret = [] + + for problem in problems: + problem = problem.copy() + + try: + m = problem['m'] + n = problem['n'] + k = problem['k'] + dtype_a = problem['dtype_a'] + dtype_b = problem['dtype_b'] + dtype_d = problem['dtype_d'] + layout = problem['layout'] + except KeyError as e: + _LOGGER.error(f"Missing required parameter {e} for problem {problem}") + raise + + operation = problem.get('operation', 'gemm') + batch_count = problem.get('batch_count', 1) + dtype_acc = problem.get('dtype_acc', 'f32') + dtype_c = problem.get('dtype_c', None) + alpha = problem.get('alpha', 1.0) + beta = problem.get('beta', 0.0) + use_fast_acc = problem.get('use_fast_acc', True) + + if operation != OperationKindNames[OperationKind.Gemm]: + raise ValueError(f"Unsupported operation {operation}") + if not (len(layout) == 3 and all(c in "nt" for c in layout)): + raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}") + layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout) + + try: + dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()] + dtypes = tuple(dtype_map[dt] for dt in dtype_list) + except KeyError as dt: + _LOGGER.error(f"Unsupported data type: {dt}") + raise + + alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]]) + alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]]) + + configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider) + problem['configs'] = configs + + ret.append(problem) + + return ret + + +def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc = 100 + max_cc = 101 + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # nvMMH assumes 2sm instruction for !(cluster_m % 2) + is_2sm = config['cluster_m'] % 2 == 0 + instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4] + math_instruction = MathInstruction( + instruction_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = [ + { + "a_type" : math_instruction.element_a, + "b_type" : math_instruction.element_b, + "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator, + "d_type" : element_d, + "acc_type" : math_instruction.element_accumulator, + "epi_type" : math_instruction.element_accumulator, + } + ] + + tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k']) + tile_description = TileDescription( + [instruction_shape[0] * tile_multiplier[0], + instruction_shape[1] * tile_multiplier[1], + instruction_shape[2] * 4 * tile_multiplier[2]], + 0, + [4,1,1], + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules = [] + if is_2sm: + schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm]) + else: + schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm]) + + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + + return configs, operations + + +def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs): + """ + Generate CUTLASS operations based on the list of configs provided by the heuristic provider + + args: + manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest) + cuda_version: Cuda compiler version for generating cutlass operations + kernel_configs: list of configs generated by the heuristic + + returns: + (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations + """ + min_cc, max_cc = 90, 90 + + if manifest is None: + # Use a dummy manifest so we can use existing CreateGemmOperator functions + manifest = Manifest() + + configs = [] + operations = [] + for config in kernel_configs: + + is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128) + layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1]) + element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d'] + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_instr_shape = [0, 0, 0] + math_instruction = MathInstruction( + dummy_instr_shape, + element_a, element_b, element_accumulator, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ) + + data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d) + if is_aligned: + layout = fix_alignments(data_types, layout, alignment_bits=128) + + # instr shape and warp config are unused for emitting 3x collective builder code + dummy_warp_count = [0, 0, 0] + tile_description = TileDescription( + [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']], + 0, + dummy_warp_count, + math_instruction, + min_cc, + max_cc, + cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k']) + ) + + schedules, stream_k_schedules = get_valid_schedules( + tile_description=tile_description, + cuda_version=cuda_version, + is_aligned=is_aligned, + data_types=data_types, + instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic + layout=layout, + gemm_kind=GemmKind.Universal3x, + enable_fp8_fast_acc=config['use_fast_acc'] + ) + + if len(schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x): + configs.append(config) + operations.append(o) + + if len(stream_k_schedules): + for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, + stream_k_schedules, + tile_schedulers=[TileSchedulerType.StreamK]): + configs.append(config) + operations.append(o) + + + return configs, operations + +def filter_manifest_and_write_heuristics_file(manifest, args): + """ + Prune a manifest according to heuristics suggestions from the problems file + + args: + manifest: Cutlass manifest to prune + args: generator.py args, requires: + - args.heuristics_problems_file + - args.heuristics_gpu + - args.heuristics_testlist_file + + returns: + A list of dictionaries, each of which has information about an operation and a problem from the input problems + """ + heuristics_problems = [] + with open(args.heuristics_problems_file, 'r') as f: + heuristics_problems = json.load(f) + gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu + mmh = MatmulHeuristics(gpu=gpu) + if any(('100' in arch) for arch in args.architectures.split(';')): + mmh.set_cta_div_n(64) + problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem) + + all_configs_and_operations = [] + operations = [] + for problem in problems_with_configs: + if any('90' in arch for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')): + problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs']) + + operations += problem_operations + problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'} + with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)] + all_configs_and_operations += with_problem_size + + for operation in operations: + manifest.add_kernel_filter(f"^{operation.procedural_name()}$") + if not all_configs_and_operations: + raise Exception("No valid configurations generated") + write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file) + return all_configs_and_operations + +def write_profiler_testlist_to_csv(configs_list, outfile_path): + """ + Write a list of configs to a testlist to be consumed by cutlass_profiler + + args: + configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries + outfile_path: Outfile path + + returns: + None + """ + profiler_testlist = configs_list.copy() + for c in profiler_testlist: + for k, v in c.items(): + if isinstance(v, DataType): + c[k] = DataTypeNames[v] + elif isinstance(v, LayoutType): + c[k] = ShortLayoutTypeNames[v] + + with open(outfile_path, mode='w', newline='') as ofile: + k_names = profiler_testlist[0].keys() + + writer = csv.DictWriter(ofile, fieldnames=k_names) + writer.writeheader() + writer.writerows(profiler_testlist) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py new file mode 100644 index 0000000000000000000000000000000000000000..01a4112a34c87d73a792cce368fede96a9315ac1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py @@ -0,0 +1,175 @@ +################################################################################################# +# +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Providers for kernel selection heuristics +""" + +import sys +import os +import glob +import logging +import ctypes +import functools + + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import DataType, LayoutType +except ImportError: + from library import DataType, LayoutType + +class MatmulHeuristics: + + def __init__(self, gpu = None): + import nvMatmulHeuristics + self.mmh_lib = nvMatmulHeuristics + self.gpu = gpu + + if 'CUTLASS_NVMMH_SO_PATH' in os.environ: + nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH']) + else: + nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx + + self.lh = nvmmhInterfaceEx( + backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"], + flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING, + load_discovery_implicitly=True, + gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None + ) + self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"]) + + def _layout_from_cutlass(self, layouts): + assert(len(layouts)==3) + full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts) + input_layouts = full_layout_str[:2].upper() + lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR") + return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout] + + def _precision_from_cutlass_dtypes(self, dtypes): + dtype_to_cublas = { + DataType.f64: 'D', + DataType.f32: 'S', + DataType.f16: 'H', + DataType.bf16: 'T', + DataType.e4m3: 'Q', + DataType.e5m2: 'R', + DataType.s32: 'I', + DataType.s8: 'B', + } + + dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes + + a_c = dtype_to_cublas[dtype_a] + + if a_c.lower() != 'q': + return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + else: + return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d] + + def set_cta_div_n(self, div_n): + cta_n_div_requirement = ctypes.c_int(div_n) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT, + ctypes.byref(cta_n_div_requirement), + ctypes.sizeof(cta_n_div_requirement) + ) + + def set_cta_div_m(self, div_m): + cta_m_div_requirement = ctypes.c_int(div_m) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT, + ctypes.byref(cta_m_div_requirement), + ctypes.sizeof(cta_m_div_requirement) + ) + + def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1): + if use_fast_acc: + disable_fast_acc_for_fp8 = ctypes.c_int(0) + else: + disable_fast_acc_for_fp8 = ctypes.c_int(1) + self.lh.setBackendValueProperty( + self.backend, + self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8, + ctypes.byref(disable_fast_acc_for_fp8), + ctypes.sizeof(disable_fast_acc_for_fp8) + ) + + precision = self._precision_from_cutlass_dtypes(dtypes) + layout = self._layout_from_cutlass(layouts) + + matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count) + configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision) + + ret = [] + for c in configs: + kernel = c['kernel'] + problem = c['problem'] + + r = {} + r['estimated_runtime'] = c['runtime'] + r['cta_tile_m'] = kernel.cta_tile_m + r['cta_tile_n'] = kernel.cta_tile_n + r['cta_tile_k'] = kernel.cta_tile_k + r['instr_tile_m'] = kernel.instr_tile_m + r['instr_tile_n'] = kernel.instr_tile_n + r['instr_tile_k'] = kernel.instr_tile_k + r['warp_tile_m'] = kernel.warp_tile_m + r['warp_tile_n'] = kernel.warp_tile_n + r['warp_tile_k'] = kernel.warp_tile_k + r['cluster_m'] = kernel.cluster_m + r['cluster_n'] = kernel.cluster_n + r['cluster_k'] = 1 + r['layout_a'] = layouts[0] + r['layout_b'] = layouts[1] + r['layout_d'] = layouts[2] + r['dtype_a'] = dtypes[0] + r['dtype_b'] = dtypes[1] + r['dtype_acc'] = dtypes[2] + r['dtype_c'] = dtypes[3] + r['dtype_d'] = dtypes[4] + r['alignment_a'] = align_a + r['alignment_b'] = align_b + r['swizzle_size'] = kernel.swizzle_factor + r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n' + r['split_k_slices'] = kernel.split_k + r['use_fast_acc'] = use_fast_acc + r['voidC'] = voidC + + ret.append(r) + + return ret + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py new file mode 100644 index 0000000000000000000000000000000000000000..56d22dc4b0705b4813b15b1b09decf53b38f7f37 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py @@ -0,0 +1,1531 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Data types and tags used for emitting CUTLASS C++ kernels +""" + +import enum +import re + +# The following block implements enum.auto() for Python 3.5 variants that don't include it such +# as the default 3.5.2 on Ubuntu 16.04. +# +# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility + +try: + from enum import auto as enum_auto +except ImportError: + __cutlass_library_auto_enum = 0 + def enum_auto() -> int: + global __cutlass_library_auto_enum + i = __cutlass_library_auto_enum + __cutlass_library_auto_enum += 1 + return i + +################################################################################################### + +# +class GeneratorTarget(enum.Enum): + Library = enum_auto() +# +GeneratorTargetNames = { + GeneratorTarget.Library: 'library' +} +# + +################################################################################################### + +# +class DataType(enum.Enum): + void = enum_auto() # primarily used to disable C tensor for epilogues + b1 = enum_auto() + u2 = enum_auto() + u4 = enum_auto() + u8 = enum_auto() + u16 = enum_auto() + u32 = enum_auto() + u64 = enum_auto() + s2 = enum_auto() + s4 = enum_auto() + s8 = enum_auto() + s16 = enum_auto() + s32 = enum_auto() + s64 = enum_auto() + e4m3 = enum_auto() + e5m2 = enum_auto() + f8 = enum_auto() + f6 = enum_auto() + f4 = enum_auto() + e3m2 = enum_auto() + e2m3 = enum_auto() + e2m1 = enum_auto() + ue8m0 = enum_auto() + ue4m3 = enum_auto() + f16 = enum_auto() + bf16 = enum_auto() + f32 = enum_auto() + tf32 = enum_auto() + f64 = enum_auto() + cf16 = enum_auto() + cbf16 = enum_auto() + cf32 = enum_auto() + ctf32 = enum_auto() + cf64 = enum_auto() + cs2 = enum_auto() + cs4 = enum_auto() + cs8 = enum_auto() + cs16 = enum_auto() + cs32 = enum_auto() + cs64 = enum_auto() + cu2 = enum_auto() + cu4 = enum_auto() + cu8 = enum_auto() + cu16 = enum_auto() + cu32 = enum_auto() + cu64 = enum_auto() + invalid = enum_auto() + +# +ShortDataTypeNames = { + DataType.s32: 'i', + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', + DataType.f16: 'h', + DataType.f32: 's', + DataType.f64: 'd', + DataType.cf32: 'c', + DataType.cf64: 'z', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', +} + +# +DataTypeNames = { + DataType.void: "void", + DataType.b1: "b1", + DataType.u2: "u2", + DataType.u4: "u4", + DataType.u8: "u8", + DataType.u16: "u16", + DataType.u32: "u32", + DataType.u64: "u64", + DataType.s2: "s2", + DataType.s4: "s4", + DataType.s8: "s8", + DataType.s16: "s16", + DataType.s32: "s32", + DataType.s64: "s64", + DataType.e4m3: 'e4m3', + DataType.e5m2: 'e5m2', + DataType.f8: 'f8', + DataType.f6: 'f6', + DataType.f4: 'f4', + DataType.e2m3: 'e2m3', + DataType.e3m2: 'e3m2', + DataType.e2m1: 'e2m1', + DataType.ue8m0: 'ue8m0', + DataType.ue4m3: 'ue4m3', + DataType.f16: "f16", + DataType.bf16: "bf16", + DataType.f32: "f32", + DataType.tf32: "tf32", + DataType.f64: "f64", + DataType.cf16: "cf16", + DataType.cbf16: "cbf16", + DataType.cf32: "cf32", + DataType.ctf32: "ctf32", + DataType.cf64: "cf64", + DataType.cu2: "cu2", + DataType.cu4: "cu4", + DataType.cu8: "cu8", + DataType.cu16: "cu16", + DataType.cu32: "cu32", + DataType.cu64: "cu64", + DataType.cs2: "cs2", + DataType.cs4: "cs4", + DataType.cs8: "cs8", + DataType.cs16: "cs16", + DataType.cs32: "cs32", + DataType.cs64: "cs64", +} + +DataTypeTag = { + DataType.void: "void", + DataType.b1: "cutlass::uint1b_t", + DataType.u2: "cutlass::uint2b_t", + DataType.u4: "cutlass::uint4b_t", + DataType.u8: "uint8_t", + DataType.u16: "uint16_t", + DataType.u32: "uint32_t", + DataType.u64: "uint64_t", + DataType.s2: "cutlass::int2b_t", + DataType.s4: "cutlass::int4b_t", + DataType.s8: "int8_t", + DataType.s16: "int16_t", + DataType.s32: "int32_t", + DataType.s64: "int64_t", + DataType.e4m3: 'cutlass::float_e4m3_t', + DataType.e5m2: 'cutlass::float_e5m2_t', + DataType.f8: 'cutlass::type_erased_dynamic_float8_t', + DataType.f6: 'cutlass::type_erased_dynamic_float6_t', + DataType.f4: 'cutlass::type_erased_dynamic_float4_t', + DataType.e2m3: 'cutlass::float_e2m3_t', + DataType.e3m2: 'cutlass::float_e3m2_t', + DataType.e2m1: 'cutlass::float_e2m1_t', + DataType.ue8m0: 'cutlass::float_ue8m0_t', + DataType.ue4m3: 'cutlass::float_ue4m3_t', + DataType.f16: "cutlass::half_t", + DataType.bf16: "cutlass::bfloat16_t", + DataType.f32: "float", + DataType.tf32: "cutlass::tfloat32_t", + DataType.f64: "double", + DataType.cf16: "cutlass::complex", + DataType.cbf16: "cutlass::complex", + DataType.cf32: "cutlass::complex", + DataType.ctf32: "cutlass::complex", + DataType.cf64: "cutlass::complex", + DataType.cu2: "cutlass::complex", + DataType.cu4: "cutlass::complex", + DataType.cu8: "cutlass::complex", + DataType.cu16: "cutlass::complex", + DataType.cu32: "cutlass::complex", + DataType.cu64: "cutlass::complex", + DataType.cs2: "cutlass::complex", + DataType.cs4: "cutlass::complex", + DataType.cs8: "cutlass::complex", + DataType.cs16: "cutlass::complex", + DataType.cs32: "cutlass::complex", + DataType.cs64: "cutlass::complex", +} + +DataTypeSize = { + DataType.void: 0, + DataType.b1: 1, + DataType.u2: 2, + DataType.u4: 4, + DataType.u8: 8, + DataType.u16: 16, + DataType.u32: 32, + DataType.u64: 64, + DataType.s2: 2, + DataType.s4: 4, + DataType.s8: 8, + DataType.s16: 16, + DataType.s32: 32, + DataType.s64: 64, + DataType.e4m3: 8, + DataType.e5m2: 8, + DataType.f8: 8, + DataType.f6: 6, + DataType.f4: 4, + DataType.e2m3: 6, + DataType.e3m2: 6, + DataType.e2m1: 4, + DataType.ue8m0: 8, + DataType.ue4m3: 8, + DataType.f16: 16, + DataType.bf16: 16, + DataType.f32: 32, + DataType.tf32: 32, + DataType.f64: 64, + DataType.cf16: 32, + DataType.cbf16: 32, + DataType.cf32: 64, + DataType.ctf32: 32, + DataType.cf64: 128, + DataType.cu2: 4, + DataType.cu4: 8, + DataType.cu8: 16, + DataType.cu16: 32, + DataType.cu32: 64, + DataType.cu64: 128, + DataType.cs2: 4, + DataType.cs4: 8, + DataType.cs8: 16, + DataType.cs16: 32, + DataType.cs32: 64, + DataType.cs64: 128, +} + +################################################################################################### +# +class BlasMode(enum.Enum): + symmetric = enum_auto() + hermitian = enum_auto() + +# +BlasModeTag = { + BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric', + BlasMode.hermitian: 'cutlass::BlasMode::kHermitian', +} + +# +class ComplexTransform(enum.Enum): + none = enum_auto() + conj = enum_auto() + +# +ComplexTransformTag = { + ComplexTransform.none: 'cutlass::ComplexTransform::kNone', + ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', +} + +# Used for cutlass3x complex kernel collective mainloop builder instantiation +ComplexTransformTag3x = { + ComplexTransform.none: 'cute::identity', + ComplexTransform.conj: 'cute::conjugate', +} + +# +RealComplexBijection = [ + (DataType.f16, DataType.cf16), + (DataType.f32, DataType.cf32), + (DataType.f64, DataType.cf64), +] + +# +def is_complex(data_type): + for r, c in RealComplexBijection: + if data_type == c: + return True + return False + +def is_block_scaled(gemm_kind): + return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x) + +def is_blockwise(gemm_kind): + return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) + +def is_grouped(gemm_kind): + return gemm_kind in (GemmKind.GroupedUniversal3x, + GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x) + +# +def get_complex_from_real(real_type): + for r, c in RealComplexBijection: + if real_type == r: + return c + return DataType.invalid + +# +def get_real_from_complex(complex_type): + for r, c in RealComplexBijection: + if complex_type == c: + return r + return DataType.invalid + +# TMA requires an alignment of 128 bits for all data types +def get_tma_alignment(data_type): + if data_type == DataType.void: + return 0 + elif DataTypeSize[data_type] == 6: + return 128 # 96B alignment for 16U6 format + else: + return 128 // DataTypeSize[data_type] + +# +class ComplexMultiplyOp(enum.Enum): + multiply_add = enum_auto() + gaussian = enum_auto() + +################################################################################################### + +# +class MathOperation(enum.Enum): + multiply_add = enum_auto() + multiply_add_saturate = enum_auto() + multiply_add_mixed_input_upcast = enum_auto() + xor_popc = enum_auto() + and_popc = enum_auto() + multiply_add_fast_bf16 = enum_auto() + multiply_add_fast_f16 = enum_auto() + multiply_add_fast_f32 = enum_auto() + multiply_add_complex_fast_f32 = enum_auto() + multiply_add_complex = enum_auto() + multiply_add_complex_gaussian = enum_auto() + multiply_add_fast_accum = enum_auto() + +# +MathOperationTag = { + MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', + MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', + MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast', + MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.and_popc: 'cutlass::arch::OpAndPopc', + MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', + MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', + MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', + MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32', + MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', + MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', + MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum', +} + +################################################################################################### + +# +class LayoutType(enum.Enum): + ColumnMajor = enum_auto() + RowMajor = enum_auto() + ColumnMajorInterleaved2 = enum_auto() + RowMajorInterleaved2 = enum_auto() + ColumnMajorInterleaved32 = enum_auto() + RowMajorInterleaved32 = enum_auto() + ColumnMajorInterleaved64 = enum_auto() + RowMajorInterleaved64 = enum_auto() + TensorNWC = enum_auto() + TensorNHWC = enum_auto() + TensorNDHWC = enum_auto() + TensorNCHW = enum_auto() + TensorNGHWC = enum_auto() + TensorNC32HW32 = enum_auto() + TensorNC64HW64 = enum_auto() + TensorC32RSK32 = enum_auto() + TensorC64RSK64 = enum_auto() + TensorKCS = enum_auto() + TensorKCSR = enum_auto() + TensorKCSRT = enum_auto() + +# +LayoutTag = { + LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', + LayoutType.RowMajor: 'cutlass::layout::RowMajor', + LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', + LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', + LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', + LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', + LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', + LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', + LayoutType.TensorNWC: 'cutlass::layout::TensorNWC', + LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', + LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', + LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', + LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', + LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', + LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', + LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', + LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', + LayoutType.TensorKCS: 'cutlass::layout::TensorKCS', + LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR', + LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT' +} + +# +TransposedLayout = { + LayoutType.ColumnMajor: LayoutType.RowMajor, + LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, + LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, + LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, + LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, + LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, + LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, + LayoutType.TensorNHWC: LayoutType.TensorNHWC +} + +# +ShortLayoutTypeNames = { + LayoutType.ColumnMajor: 'n', + LayoutType.ColumnMajorInterleaved2: 'n2', + LayoutType.ColumnMajorInterleaved32: 'n32', + LayoutType.ColumnMajorInterleaved64: 'n64', + LayoutType.RowMajor: 't', + LayoutType.RowMajorInterleaved2: 't2', + LayoutType.RowMajorInterleaved32: 't32', + LayoutType.RowMajorInterleaved64: 't64', + LayoutType.TensorNWC: 'nwc', + LayoutType.TensorNHWC: 'nhwc', + LayoutType.TensorNDHWC: 'ndhwc', + LayoutType.TensorNCHW: 'nchw', + LayoutType.TensorNGHWC: 'nghwc', + LayoutType.TensorNC32HW32: 'nc32hw32', + LayoutType.TensorNC64HW64: 'nc64hw64', + LayoutType.TensorC32RSK32: 'c32rsk32', + LayoutType.TensorC64RSK64: 'c64rsk64', + LayoutType.TensorKCS: 'kcs', + LayoutType.TensorKCSR: 'kcsr', + LayoutType.TensorKCSRT: 'kcsrt' +} + +# +ShortComplexLayoutNames = { + (LayoutType.ColumnMajor, ComplexTransform.none): 'n', + (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', + (LayoutType.RowMajor, ComplexTransform.none): 't', + (LayoutType.RowMajor, ComplexTransform.conj): 'h' +} + +################################################################################################### +class KernelScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + Multistage = enum_auto() + CpAsyncWarpSpecialized = enum_auto() + CpAsyncWarpSpecializedPingpong = enum_auto() + CpAsyncWarpSpecializedCooperative = enum_auto() + Tma = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedPingpong = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecializedFP8FastAccum = enum_auto() + TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + ImplicitTmaWarpSpecializedSm90 = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() + PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto() + + BlockwiseTmaWarpSpecializedCooperative = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto() + BlockwiseTmaWarpSpecializedPingpong = enum_auto() + PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto() + + TmaWarpSpecialized1SmSm100 = enum_auto() + TmaWarpSpecialized2SmSm100 = enum_auto() + ImplicitTmaWarpSpecialized1SmSm100 = enum_auto() + ImplicitTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto() + PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + SparseTmaWarpSpecialized1SmSm100 = enum_auto() + SparseTmaWarpSpecialized2SmSm100 = enum_auto() + + BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto() + BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto() + + BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() + BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() + + PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto() + PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto() + + + Mxf4TmaWarpSpecialized1SmSm100 = enum_auto() + Mxf4TmaWarpSpecialized2SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized1SmSm100 = enum_auto() + Nvf4TmaWarpSpecialized2SmSm100 = enum_auto() + + # FP4 Ultra + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto() + + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto() + PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto() + + Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto() + Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto() + Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto() + Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto() + + F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto() + + BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto() + BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto() + +KernelScheduleTag = { + KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', + KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', + KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative', + KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', + KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', + KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100', + KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100', + + KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100', + KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100', + + KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100', + KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100', + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100', + + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100', + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100', + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100', + + # FP4 Ultra + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100", + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100", + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100", + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100", + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch', + + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120', + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120', + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120', + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120', + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120', + + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120', +} + +# +KernelScheduleSuffixes = { + KernelScheduleType.ScheduleAuto: '', + KernelScheduleType.Multistage: '_cpasync', + KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative', + KernelScheduleType.Tma: '_unspecialized', + KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', + KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + + KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm', + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm', + + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm', + + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', + + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong', + + KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm', + KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm', + KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm', + KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf', + + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf', + KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf', + + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q', + KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q', + KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16', + KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs16', + KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32', + KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32', + + KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q', + + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q', + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q' +} + +class EpilogueScheduleType(enum.Enum): + ScheduleAuto = enum_auto() + EpilogueTransposed = enum_auto() + NoSmemWarpSpecialized = enum_auto() + PtrArrayNoSmemWarpSpecialized = enum_auto() + NoSmemWarpSpecialized1Sm = enum_auto() + NoSmemWarpSpecialized2Sm = enum_auto() + FastF32NoSmemWarpSpecialized1Sm = enum_auto() + FastF32NoSmemWarpSpecialized2Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + BlockwiseNoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayNoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto() + PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto() + TmaWarpSpecialized = enum_auto() + TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecialized1Sm = enum_auto() + TmaWarpSpecialized2Sm = enum_auto() + PtrArrayTmaWarpSpecialized1Sm = enum_auto() + PtrArrayTmaWarpSpecialized2Sm = enum_auto() + PtrArrayTmaWarpSpecializedPingpong = enum_auto() + PtrArrayTmaWarpSpecializedCooperative = enum_auto() + +# +EpilogueScheduleTag = { + EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto', + EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed', + EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized', + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm', + EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized', + EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative', + EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm', + EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong', +} + +# +EpilogueScheduleSuffixes = { + EpilogueScheduleType.ScheduleAuto: '', + EpilogueScheduleType.EpilogueTransposed: '', + EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem', + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem', + EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem', + EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', + EpilogueScheduleType.TmaWarpSpecialized1Sm: '', + EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '', + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma', + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma', +} + +class EpilogueFunctor3x(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationBlockScaleFactor = enum_auto() + +# +EpilogueFunctor3xTag = { + EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination', + EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor', +} + +# TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type) +def is_tma_epilogue(epilogue_schedule_type): + return epilogue_schedule_type in [ + EpilogueScheduleType.ScheduleAuto, + EpilogueScheduleType.TmaWarpSpecialized, + EpilogueScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + ] + +def to_grouped_schedule(schedule, grouped): + if not grouped: + return schedule + + group_schedule_map = { + # SM90 + KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative, + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative, + KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong, + KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong, + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong, + EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative, + EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized, + # SM100 + KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100, + KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100, + KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100, + KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100, + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100, + KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100, + KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100, + KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100, + EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm, + EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm, + EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm, + EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm, + # SM103 + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, + KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, + } + + return group_schedule_map[schedule] + +class TileSchedulerType(enum.Enum): + Default = enum_auto() + Persistent = enum_auto() + StreamK = enum_auto() +# +TileSchedulerTag = { + TileSchedulerType.Default: 'void', + TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler', + TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler', +} + +# +TileSchedulerSuffixes = { + TileSchedulerType.Default: '', + TileSchedulerType.Persistent: '', + TileSchedulerType.StreamK: '_stream_k', +} + +################################################################################################### + +# +class SideMode(enum.Enum): + Left = enum_auto() + Right = enum_auto() + +# +SideModeTag = { + SideMode.Left: 'cutlass::SideMode::kLeft', + SideMode.Right: 'cutlass::SideMode::kRight' +} + +# +ShortSideModeNames = { + SideMode.Left: 'ls', + SideMode.Right: 'rs' +} + +################################################################################################### + +# +class FillMode(enum.Enum): + Lower = enum_auto() + Upper = enum_auto() + +# +FillModeTag = { + FillMode.Lower: 'cutlass::FillMode::kLower', + FillMode.Upper: 'cutlass::FillMode::kUpper' +} + +# +ShortFillModeNames = { + FillMode.Lower: 'l', + FillMode.Upper: 'u' +} + +################################################################################################### + +# +class DiagType(enum.Enum): + NonUnit = enum_auto() + Unit = enum_auto() + +# +DiagTypeTag = { + DiagType.NonUnit: 'cutlass::DiagType::kNonUnit', + DiagType.Unit: 'cutlass::DiagType::kUnit' +} + +# +ShortDiagTypeNames = { + DiagType.NonUnit: 'nu', + DiagType.Unit: 'un' +} + +################################################################################################### + +# +class OpcodeClass(enum.Enum): + Simt = enum_auto() + TensorOp = enum_auto() + WmmaTensorOp = enum_auto() + SparseTensorOp = enum_auto() + BlockScaledTensorOp = enum_auto() + + +OpcodeClassNames = { + OpcodeClass.Simt: 'simt', + OpcodeClass.TensorOp: 'tensorop', + OpcodeClass.WmmaTensorOp: 'wmma_tensorop', + OpcodeClass.SparseTensorOp: 'sptensorop', + OpcodeClass.BlockScaledTensorOp: 'bstensorop' +} + +OpcodeClassTag = { + OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', + OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', + OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', + OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp', + OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp' +} + +################################################################################################### + +# +class OperationKind(enum.Enum): + Gemm = enum_auto() + RankK = enum_auto() + Rank2K = enum_auto() + Trmm = enum_auto() + Symm = enum_auto() + Conv2d = enum_auto() + Conv3d = enum_auto() + +# +OperationKindNames = { + OperationKind.Gemm: 'gemm' + , OperationKind.RankK: 'rank_k' + , OperationKind.Rank2K: 'rank_2k' + , OperationKind.Trmm: 'trmm' + , OperationKind.Symm: 'symm' + , OperationKind.Conv2d: 'conv2d' + , OperationKind.Conv3d: 'conv3d' +} + +# +class Target(enum.Enum): + library = enum_auto() +# +ArchitectureNames = { + 50: 'maxwell', + 60: 'pascal', + 61: 'pascal', + 70: 'volta', + 75: 'turing', + 80: 'ampere', + 89: 'ada', + 90: 'hopper' +} + +# +SharedMemPerCC = { + 70: 96, # 96KB of SMEM + 72: 96, # 96KB of SMEM + 75: 64, # 64KB of SMEM + 80: 163, # 163KB of SMEM - 1KB reserved for the driver + 86: 99, # 99KB of SMEM - 1KB reserved for the driver + 87: 163, # 163KB of SMEM - 1KB reserved for the driver + 89: 99, # 99KB of SMEM - 1KB reserved for the driver + 90: 227, # 227KB of SMEM - 1KB reserved for the driver + 100: 227, # 227KB of SMEM - 1KB reserved for the driver +} + +################################################################################################### + +# +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text + +################################################################################################### + +# +class GemmKind(enum.Enum): + Gemm = enum_auto() + Sparse = enum_auto() + Universal = enum_auto() + Universal3x = enum_auto() + SparseUniversal3x = enum_auto() + PlanarComplex = enum_auto() + PlanarComplexArray = enum_auto() + Grouped = enum_auto() + BlockScaledUniversal3x = enum_auto() + GroupedUniversal3x = enum_auto() + GroupedBlockScaledUniversal3x = enum_auto() + BlockwiseUniversal3x = enum_auto() + GroupedBlockwiseUniversal3x = enum_auto() + +# +GemmKindNames = { + GemmKind.Gemm: "gemm", + GemmKind.Sparse: "spgemm", + GemmKind.Universal: "gemm", + GemmKind.Universal3x: "gemm", + GemmKind.SparseUniversal3x: "spgemm", + GemmKind.PlanarComplex: "gemm_planar_complex", + GemmKind.PlanarComplexArray: "gemm_planar_complex_array", + GemmKind.Grouped: "gemm_grouped", + GemmKind.BlockScaledUniversal3x: "gemm", + GemmKind.GroupedUniversal3x: "gemm_grouped", + GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped", + GemmKind.BlockwiseUniversal3x: "gemm", + GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped" +} + +# +class RankKKind(enum.Enum): + Universal = enum_auto() + +# +RankKKindNames = { + RankKKind.Universal: "rank_k" +} + +# +class TrmmKind(enum.Enum): + Universal = enum_auto() + +# +TrmmKindNames = { + TrmmKind.Universal: "trmm" +} + +# +class SymmKind(enum.Enum): + Universal = enum_auto() + +# +SymmKindNames = { + SymmKind.Universal: "symm" +} + +# +class EpilogueFunctor(enum.Enum): + LinearCombination = enum_auto() + LinearCombinationClamp = enum_auto() + +# +EpilogueFunctorTag = { + EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', + EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', +} + +# +class MixedInputMode(enum.Enum): + ConvertOnly = enum_auto() + ScaleOnly = enum_auto() + ScaleWithZeroPoint = enum_auto() + +# +class SwizzlingFunctor(enum.Enum): + Identity1 = enum_auto() + Identity2 = enum_auto() + Identity4 = enum_auto() + Identity8 = enum_auto() + Horizontal = enum_auto() + StridedDgradIdentity1 = enum_auto() + StridedDgradIdentity4 = enum_auto() + StridedDgradHorizontal = enum_auto() + StreamK = enum_auto() + +# +SwizzlingFunctorTag = { + SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', + SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', + SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle', + SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>', + SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>', + SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle', + SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK', +} + +# +class GroupScheduleMode(enum.Enum): + Device = enum_auto(), + Host = enum_auto() + +# +GroupScheduleModeTag = { + GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly', + GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute' +} + +# +ShortGroupScheduleModeNames = { + GroupScheduleMode.Device: 'Device', + GroupScheduleMode.Host: 'Host' +} + +################################################################################################### + +# +class ConvKind(enum.IntEnum): + Fprop = 0 + Dgrad = 1 + Wgrad = 2 + +# +ConvKindTag = { + ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', + ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', + ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' +} + +ConvKindNames = { + ConvKind.Fprop: 'fprop', + ConvKind.Dgrad: 'dgrad', + ConvKind.Wgrad: 'wgrad', +} + +class ConvMode(enum.IntEnum): + CrossCorrelation = 0 + Convolution = 1 + +# +class IteratorAlgorithm(enum.Enum): + Analytic = 0 + Optimized = 1 + FixedChannels = 2 + FewChannels = 3 + FixedStrideDilation = 4 + +# +IteratorAlgorithmTag = { + IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', + IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', + IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels', + IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels', + IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation' +} + +IteratorAlgorithmNames = { + IteratorAlgorithm.Analytic: 'analytic', + IteratorAlgorithm.Optimized: 'optimized', + IteratorAlgorithm.FixedChannels: 'fixed_channels', + IteratorAlgorithm.FewChannels: 'few_channels', + IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation' +} + +# +class StrideSupport(enum.Enum): + Strided = 0 + Unity = 1 + Fixed = 2 + +# +StrideSupportTag = { + StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', + StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', + StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed' +} + +StrideSupportNames = { + StrideSupport.Strided: '', + StrideSupport.Unity: 'unity_stride', + StrideSupport.Fixed: 'fixed_stride' +} + +# +class GroupMode(enum.Enum): + NoneGroup = enum_auto() # dense conv (G=1) + SingleGroup = enum_auto() # grouped convolution (single group per CTA) + MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) + Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) + +# +GroupModeTag = { + GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone', + GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup', + GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup', + GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise', +} + +GroupModeNames = { + GroupMode.NoneGroup: '', + GroupMode.SingleGroup: 'single_group', + GroupMode.MultipleGroup: 'multiple_group', + GroupMode.Depthwise: 'depthwise', +} + +DynamicClusterShape = [0, 0, 1] + +################################################################################################### + +# +class MathInstruction: + def __init__(self, + instruction_shape, \ + element_a, element_b, element_accumulator, \ + opcode_class, math_operation = MathOperation.multiply_add \ + , element_scale_factor = None + ): + + self.instruction_shape = instruction_shape + self.element_a = element_a + self.element_b = element_b + self.element_accumulator = element_accumulator + self.opcode_class = opcode_class + self.math_operation = math_operation + self.element_scale_factor = element_scale_factor + +# +class TileDescription: + + def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None): + self.threadblock_shape = threadblock_shape + self.tile_shape = threadblock_shape + self.stages = stages + self.warp_count = warp_count + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + self.cluster_shape = cluster_shape + self.explicit_vector_sizes = explicit_vector_sizes + + def procedural_name(self): + if self.minimum_compute_capability >= 90: + return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format( + tbm = self.threadblock_shape[0], + tbn = self.threadblock_shape[1], + tbk = self.threadblock_shape[2], + cm = self.cluster_shape[0], + cn = self.cluster_shape[1], + ck = self.cluster_shape[2], + s = self.stages) + else: + return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) + +# +class Direct2dConvFixedStrideDilationTileDescription: + def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] + self.threadblock_output_shape = threadblock_output_shape + self.filter_shape = filter_shape + self.stages = stages + self.warp_count = warp_count + self.stride = stride + self.dilation = dilation + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.threadblock_output_shape[0], + self.threadblock_output_shape[1], + self.threadblock_output_shape[2], + self.threadblock_output_shape[3], + self.stages, + self.filter_shape[0], + self.filter_shape[1]) + # Fixed Strided and dilation + if self.stride != [-1, -1] and self.dilation != [-1, -1]: + str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], + self.stride[1], + self.dilation[0], + self.dilation[1]) + return str_name + +# +class Direct2dConvFixedStrideDilationTileDescription: + def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute): + self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]] + self.threadblock_output_shape = threadblock_output_shape + self.filter_shape = filter_shape + self.stages = stages + self.warp_count = warp_count + self.stride = stride + self.dilation = dilation + self.math_instruction = math_instruction + self.minimum_compute_capability = min_compute + self.maximum_compute_capability = max_compute + + def procedural_name(self): + str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0], + self.threadblock_shape[1], + self.threadblock_shape[2], + self.threadblock_output_shape[0], + self.threadblock_output_shape[1], + self.threadblock_output_shape[2], + self.threadblock_output_shape[3], + self.stages, + self.filter_shape[0], + self.filter_shape[1]) + # Fixed Strided and dilation + if self.stride != [-1, -1] and self.dilation != [-1, -1]: + str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0], + self.stride[1], + self.dilation[0], + self.dilation[1]) + return str_name + +# +class TensorDescription: + def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): + self.element = element + self.layout = layout + self.alignment = alignment + self.complex_transform = complex_transform + +# +class SymmetricTensorDescription: + def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left): + self.element = element + self.layout = layout + self.fill_mode = fill_mode + self.alignment = alignment + self.complex_transform = complex_transform + self.side_mode = side_mode + +# +class TriangularTensorDescription: + def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none): + self.element = element + self.layout = layout + self.side_mode = side_mode + self.fill_mode = fill_mode + self.diag_type = diag_type + self.alignment = alignment + self.complex_transform = complex_transform + +# +def CalculateSmemUsage(operation): + cta_shape = operation.tile_description.threadblock_shape + stages = operation.tile_description.stages + + if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse: + # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity) + if DataTypeSize[operation.A.element] == 32: + elements_per_8b_md = 2 + elif DataTypeSize[operation.A.element] == 4: + elements_per_8b_md = 8 + else: + elements_per_8b_md = 4 + + smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \ + DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \ + cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md + else: + # Few BLAS3 operations only have A tensor + data_type_size_a = DataTypeSize[operation.A.element] + data_type_size_b = DataTypeSize[operation.A.element] + if operation.is_mixed_input(): + data_type_size_b = DataTypeSize[operation.B.element] + + smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \ + data_type_size_b * cta_shape[1] * cta_shape[2] // 8 + + smem_usage = smem_per_stage * stages + return (smem_usage >> 10) + + +class GemmUniversalMode(enum.IntEnum): + """ + Types corresponding to GemmUniversalMode + """ + Gemm = 0 + GemmSplitKParallel = 1 + Batched = 2 + Array = 3 + + +class SplitKMode(enum.IntEnum): + """ + Types corresponding to SplitKMode + """ + NoneSplitK = 0 + Serial = 1 + Parallel = 2 diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..5733ef26322794ee650dfa0c8c2b170bd8c6f3e5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py @@ -0,0 +1,868 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for filtering CUTLASS library kernels and emitting library intitialization +and building code +""" + +import enum +import logging +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.gemm_operation import * + from cutlass_library.rank_k_operation import * + from cutlass_library.rank_2k_operation import * + from cutlass_library.trmm_operation import * + from cutlass_library.symm_operation import * + from cutlass_library.conv2d_operation import * + from cutlass_library.conv3d_operation import * +except ImportError: + from library import * + from gemm_operation import * + from rank_k_operation import * + from rank_2k_operation import * + from trmm_operation import * + from symm_operation import * + from conv2d_operation import * + from conv3d_operation import * + +################################################################################################### +_LOGGER = logging.getLogger(__name__) + + +class EmitOperationKindAll: + """ + Emit the OperationKind-level CUTLASS library initialization code. + The code is generated in the {generated_path}/{operation_kind} directory + (e.g., tools/library/generated/gemm in the build directory, + for OperationKind=Gemm), in the all_{operation_kind}_operations.cu file + (e.g., all_gemm_operations.cu for OperationKind=Gemm). + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_{configuration_name}(Manifest& manifest); + + The file also _defines_ the following function in that namespace. + + void initialize_all_{operation_kind}_operations(Manifest& manifest); + + That function calls all of the functions declared in this file. + Those functions are defined in subdirectories + (which this class does not create). + """ + + def __init__(self, generated_path, kind, args): + self.generated_path = generated_path + self.kind = kind + self.args = args + + self.header_template =""" +/* + Generated by manifest.py - Do not edit. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.entry_template = """ + +// +// Entry point to construct operations +// +void initialize_all_${operation_name}_operations(Manifest &manifest) { +""" + self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" + self.configuration_template =" initialize_${configuration_name}(manifest);\n" + + self.epilogue_template ="""} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +""" + + # + def __enter__(self): + _LOGGER.debug("*** EmitOperationKindAll::__enter__") + + self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind]) + _LOGGER.debug('*** operation_path (directory to create): ' + + str(self.operation_path)); + os.makedirs(self.operation_path, exist_ok=True) + + self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = [self.top_level_path,] + + self.configurations = [] + + return self + + # + def emit(self, operations): + _LOGGER.debug('*** EmitOperationKindAll::emit') + _LOGGER.debug(f"*** len(operations): {len(operations)}") + _LOGGER.debug(f"*** min_cc list: {sorted(min_cc for min_cc, _ in operations.items())}") + + for min_cc, configurations in sorted(operations.items()): + _LOGGER.debug(f"*** min_cc={min_cc}") + + for configuration_name, _ in configurations.items(): + _LOGGER.debug(f"*** configuration_name={configuration_name}") + self.configurations.append(configuration_name) + self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitOperationKindAll::__exit__") + + self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]})) + + for configuration_name in self.configurations: + self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name})) + + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + + +class EmitOperationKindLibrary: + """ + Emit the CUTLASS library initialization code for each OperationKind. + The code is generated in the directory + {generated_path}/{operation_kind}/{min_cc} + (e.g., tools/library/generated/gemm/90 in the build directory, + for min_cc=90 and OperationKind=Gemm), in the file + all_sm{min_cc}_{operation_kind}_operations.cu + (e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm). + The min_cc variable here indicates the minimum GPU architecture version + that the things to be initialized require. + For example, min_cc=90 indicates sm90. + + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_all_sm{min_cc}_{subclass_name}_{extended_name}_operations(Manifest& manifest); + + where extended_name is operation.extended_name() for all the operations + given to the emit method (which see below). (All operations for a given + configuration_name are guaranteed to have the same extended_name().) + + The file also _defines_ the following function in that namespace. + + void initialize_all_sm{min_cc}__{operation_kind}_operations(Manifest& manifest); + + That function calls all of the functions declared in this file. + Those functions are defined in subdirectories. + The mapping from OperationKind to emitter handles the details + of what happens in each of those subdirectories. + """ + + def __init__(self, generated_path, min_cc, kind, args): + self.generated_path = generated_path + self.min_cc = min_cc + self.kind = kind + self.args = args + self.emitters = { + OperationKind.Gemm: EmitGemmConfigurationLibrary, + OperationKind.Conv2d: EmitConv2dConfigurationLibrary, + OperationKind.Conv3d: EmitConv3dConfigurationLibrary, + OperationKind.RankK: EmitRankKConfigurationLibrary, + OperationKind.Rank2K: EmitRank2KConfigurationLibrary, + OperationKind.Trmm: EmitTrmmConfigurationLibrary, + OperationKind.Symm: EmitSymmConfigurationLibrary + } + + self.header_template =""" +/* + Generated by manifest.py - Do not edit. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + self.entry_template = """ + +// +// Entry point to construct operations +// +void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) { +""" + self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n" + self.configuration_template = " initialize_${configuration_name}(manifest);\n" + self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n" + self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n" + self.epilogue_template ="""} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +""" + + # + def __enter__(self): + _LOGGER.debug("*** EmitOperationKindLibrary::__enter__") + _LOGGER.debug(f"*** generated_path: {str(self.generated_path)}") + _LOGGER.debug(f"*** OperationKindNames[kind]: {OperationKindNames[self.kind]}") + _LOGGER.debug(f"*** min_cc: {self.min_cc}") + + self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc)) + _LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}") + os.makedirs(self.operation_path) + + self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}") + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = {} + + # Each {operation_kind x cc} combination is further decomposed by the instruction + # types used. This dictionary used to track the file handles for the top-level + # files of each subclass + self.subclass_files = {} + + # Configurations in each sub class + self.subclass_configurations = {} + + return self + + # + def emit(self, configuration_name, operations): + _LOGGER.debug("*** EmitOperationKindLibrary::emit") + _LOGGER.debug(f"*** configuration_name: {configuration_name}") + + assert len(operations) > 0 + + # The extended name for all operations of a given configuration_name is guaranteed + # to be the same because extended_name() is used in defining configuration_name. Thus, + # we can safely use the extended_name() of the first operation. + extended_name = operations[0].extended_name() + _LOGGER.debug('*** extended_name (for all ops): ' + extended_name) + + # Create a directory for operations with this subclass if it does not exist + if extended_name not in self.subclass_files: + subclass_path = os.path.join(self.operation_path, extended_name) + _LOGGER.debug(f"*** subclass_path: {str(subclass_path)}") + os.mkdir(subclass_path) + + self.subclass_configurations[extended_name] = [] + + # Open a new top-level file for this sub class + subclass_top_level_path = os.path.join( + subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu") + _LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' + + 'OperationKind): ' + str(subclass_top_level_path)) + + self.subclass_files[extended_name] = open(subclass_top_level_path, "w") + self.subclass_files[extended_name].write(self.header_template) + + self.source_files[extended_name] = [subclass_top_level_path] + + subclass_dir = os.path.dirname(self.subclass_files[extended_name].name) + _LOGGER.debug('*** subclass_dir: ' + str(subclass_dir)) + + with self.emitters[self.kind](subclass_dir, configuration_name) as configuration_emitter: + for operation in operations: + configuration_emitter.emit(operation) + + _LOGGER.debug('*** configuration_emitter.configuration_path: ' + + str(configuration_emitter.configuration_path)) + self.source_files[extended_name].append(configuration_emitter.configuration_path) + + self.subclass_configurations[extended_name].append(configuration_name) + self.subclass_files[extended_name].write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} )) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitOperationKindLibrary::__exit__") + for subclass_name, subclass_file in sorted(self.subclass_files.items()): + subclass_cfg = { + 'min_cc': str(self.min_cc), + 'subclass_name': subclass_name, + 'operation_name': OperationKindNames[self.kind] + } + self.top_level_file.write(SubstituteTemplate(self.subclass_prototype_template, subclass_cfg)) + + self.top_level_file.write( + SubstituteTemplate(self.entry_template, { + 'min_cc': str(self.min_cc), + 'subclass_name': '', + 'operation_name': OperationKindNames[self.kind] + })) + + # Finish and close all subclass files + for subclass_name, subclass_file in sorted(self.subclass_files.items()): + subclass_cfg = { + 'min_cc': str(self.min_cc), + 'subclass_name': subclass_name, + 'operation_name': OperationKindNames[self.kind] + } + subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg)) + + for configuration in self.subclass_configurations[subclass_name]: + subclass_file.write( + SubstituteTemplate(self.configuration_template, { + 'configuration_name': configuration + })) + + subclass_file.write(self.epilogue_template) + subclass_file.close() + + # Write the call to initialize_all for this subclass to the top-level file + self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg)) + + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + +class EmitInterfaceLibrary: + """ + Emit the topmost-level CUTLASS library initialization code. + The code is generated in the generated_path directory + (e.g., tools/library/generated in the build directory), + in the initialize_all.cpp file. + That file declares several functions in namespace cutlass::library. + The functions all have this form, + + void initialize_all_{operation_kind}_operations(Manifest& manifest); + + where {operation_kind} abbreviates the "kind" of operation + (e.g., gemm for matrix-matrix multiply, conv2d for 2-d convolution, + or trmm for triangular solve with multiple right-hand sides). + The definitions of these functions live in subdirectories. + + The file also _defines_ the following function in that namespace. + + void initialize_all(Manifest& manifest); + + That function first prepares the manifest, and then + calls all of the functions declared in this file. + """ + + def __init__(self, generated_path, operation_count, args): + self.generated_path = generated_path + self.args = args + + self.prototypes = [] + self.fn_calls = [] + self.operation_count = str(operation_count) + + self.top_level_hdr_template = ''' +/* + Generated by manifest.py - Do not edit. +*/ +''' + self.top_level_prologue = ''' + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +namespace cutlass { +\tnamespace library { + +${prototypes} +''' + + self.top_level_initialize_kind = ''' +\t\tvoid initialize_all_${kind}_operations(Manifest &manifest) { +${fn_calls} +\t\t} +''' + + self.top_level_initialize = ''' +\t\tvoid initialize_all(Manifest &manifest) { +\t\t\tmanifest.reserve(${operation_count});\n +${fn_calls} +\t\t} +''' + + self.top_level_suffix = ''' +\t} // namespace library +} // namespace cutlass + +''' + + # + def __enter__(self): + _LOGGER.debug("*** EmitInterfaceLibrary::__enter__") + + self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp') + _LOGGER.debug("*** top_level_path: " + str(self.top_level_path)) + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.top_level_hdr_template) + + self.source_files = [self.top_level_path,] + + return self + + # + def emit(self, operation_name): + _LOGGER.debug("*** EmitInterfaceLibrary::emit") + _LOGGER.debug("*** operation_name: " + operation_name) + + self.prototypes.append(SubstituteTemplate( + "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);", + {'operation_kind': operation_name})) + + self.fn_calls.append(SubstituteTemplate( + "\t\t\tinitialize_all_${operation_kind}_operations(manifest);", + {'operation_kind': operation_name})) + + # + def __exit__(self, exception_type, exception_value, traceback): + _LOGGER.debug("*** EmitInterfaceLibrary::__exit__") + + self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)})) + + # Write out initialize_all method + self.top_level_file.write(SubstituteTemplate(self.top_level_initialize, + {'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)})) + + self.top_level_file.write(self.top_level_suffix) + self.top_level_file.close() + +################################################################################################### +################################################################################################### + +class Options: + def __init__(self): + pass + +################################################################################################### + +# +class Manifest: + + # + def __init__(self, args = None): + self.operations = {} + self.args = args + self.operation_count = 0 + self.operations_by_name = {} + + self.kernel_filter = '' + self.kernel_filter_list = [] + self.kernel_names = [] + self.operations_enabled = [] + self.selected_kernels = [] + self.ignore_kernel_names = [] + self.exclude_kernel_names = [] + self.compute_capabilities_baseline = [50,] + self.compute_capabilities_feature_set = ['50',] + self.curr_build_dir = '.' + self.filter_by_cc = True + + if self.args: + self.kernel_filter = self.args.kernels + self.curr_build_dir = args.curr_build_dir + + # A common user error is to use commas instead of semicolons. + if ',' in args.architectures: + raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures) + + self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',] + self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set)) + + if args.filter_by_cc in ['false', 'False', '0']: + self.filter_by_cc = False + + if args.operations == 'all': + self.operations_enabled = [] + else: + operations_list = [ + OperationKind.Gemm + , OperationKind.Conv2d + , OperationKind.Conv3d + , OperationKind.RankK + , OperationKind.Trmm + , OperationKind.Symm + ] + self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')] + + if args.kernels == 'all': + self.kernel_names = [] + else: + self.kernel_names = [x for x in args.kernels.split(',') if x != ''] + + self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != ''] + + if args.kernel_filter_file is None: + self.kernel_filter_list = [] + else: + self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) + _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format( + filter_count = len(self.kernel_filter_list), + filter_file = args.kernel_filter_file)) + + self.operation_count = 0 + self.operations_by_name = {} + self.disable_full_archs_compilation = args.disable_full_archs_compilation + self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != '' + self.instantiation_level = 0 + try: + self.instantiation_level = int(args.instantiation_level) + except ValueError: + self.instantiation_level = 0 + + def add_kernel_filter(self, filter_str): + filter_re = re.compile(filter_str) + + self.kernel_filter_list.append(filter_re) + + def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992): + # Non-negative integer which determines how many kernels are instantiated. + # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations. + # increasing first digit reduces schedule / mixed type pruning, + # increasing second digit generates more cluster sizes, + # increasing third digit generates more MMA multipliers, + # increasing fourth digit generates more instruction shapes. + + if self.instantiation_level > 0: + return self.instantiation_level + + elif self.is_kernel_filter_set_to_all: + return exhaustive_level + + elif self.kernel_filter == '': + return pruned_level + + else: + return default_level + + + def get_kernel_filters(self, kernelListFile): + if os.path.isfile(kernelListFile): + with open(kernelListFile, 'r') as fileReader: + lines = [line.rstrip() for line in fileReader if not line.startswith("#")] + + lines = [re.compile(line) for line in lines if line] + return lines + else: + return [] + + # + def filter_out_kernels(self, kernel_name, kernel_filter_list): + + for kernel_filter_re in kernel_filter_list: + if kernel_filter_re.search(kernel_name) is not None: + return True + + return False + + + # + def _filter_string_matches(self, filter_string, haystack): + ''' Returns true if all substrings appear in the haystack in order''' + substrings = filter_string.split('*') + for sub in substrings: + idx = haystack.find(sub) + if idx < 0: + return False + haystack = haystack[idx + len(sub):] + return True + + # + def filter(self, operation): + ''' Filtering operations based on various criteria''' + + # filter based on compute capability + enabled = not (self.filter_by_cc) + + for cc in self.compute_capabilities_baseline: + + if cc >= operation.tile_description.minimum_compute_capability and \ + cc <= operation.tile_description.maximum_compute_capability and \ + (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)): + + enabled = True + break + + if not enabled: + return False + + if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled: + return False + + name = operation.procedural_name() + + # eliminate duplicates + if name in self.operations_by_name.keys(): + return False + + # Filter based on list of valid substrings + if len(self.kernel_names): + enabled = False + + # compare against the include list + for name_substr in self.kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.") + enabled = True + break + else: + _LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.") + + # compare against the exclude list + for name_substr in self.ignore_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.") + + if len(self.kernel_filter_list) > 0: + if self.filter_out_kernels(name, self.kernel_filter_list): + _LOGGER.debug(f"Kernel {name} matched via kernel filter file.") + enabled = True + else: + _LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.") + enabled = False + + # CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect + # if CUTLASS_LIBRARY_KERNELS was specified. + # Changing that would break backwards compatibility. + # Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS, + # that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified. + for name_substr in self.exclude_kernel_names: + if self._filter_string_matches(name_substr, name): + _LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.") + enabled = False + break + else: + _LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.") + + # TODO: filter based on compute data type + return enabled + # + + # + def append(self, operation): + ''' + Inserts the operation. + + operation_kind -> configuration_name -> [] + ''' + + if self.filter(operation): + + self.selected_kernels.append(operation.procedural_name()) + + self.operations_by_name[operation.procedural_name()] = operation + + # add the configuration + configuration_name = operation.configuration_name() + + # Split operations by minimum CC + min_cc = operation.arch + + if operation.operation_kind not in self.operations.keys(): + self.operations[operation.operation_kind] = {} + + if min_cc not in self.operations[operation.operation_kind]: + self.operations[operation.operation_kind][min_cc] = {} + + if configuration_name not in self.operations[operation.operation_kind][min_cc].keys(): + self.operations[operation.operation_kind][min_cc][configuration_name] = [] + + self.operations[operation.operation_kind][min_cc][configuration_name].append(operation) + self.operation_count += 1 + else: + _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name())) + # + + def emit_manifest_cmake(self, manifest_path, top_level_path, source_files): + with open(manifest_path, "w") as manifest_file: + + target_text = SubstituteTemplate("""cutlass_target_sources(cutlass_library_objs PRIVATE + """, { }) + manifest_file.write(target_text + '\n\n') + manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/'))) + generated_path = os.path.join(self.curr_build_dir, 'generated') + for kind in self.operations.keys(): + kind_str = OperationKindNames[kind] + all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu").replace('\\', '/') + manifest_file.write(f" {all_kind_file}\n") + manifest_file.write(')\n\n') + + for kind in self.operations.keys(): + for min_cc in sorted(self.operations[kind].keys()): + for subclass in sorted(source_files[kind][min_cc].keys()): + target_text = SubstituteTemplate("""cutlass_add_cutlass_library( + SUFFIX ${kind}_sm${min_cc}_${subclass} +""", { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass }) + manifest_file.write(target_text + '\n\n') + + for source_file in source_files[kind][min_cc][subclass]: + manifest_file.write(" %s\n" % str(source_file.replace('\\', '/'))) + + manifest_file.write(")\n") + + if self.disable_full_archs_compilation: + self.emit_disable_full_archs_compilation(manifest_file, source_files) + + def emit_disable_full_archs_compilation(manifest_file, source_files): + def for_hopper(name): + pass + + def for_ampere(name): + return "16816" in name or \ + "16832" in name or \ + "16864" in name or \ + ("1688" in name and "tf32" in name) + + def for_turing(name): + return ("1688" in name and "tf32" not in name) or \ + "8816" in name + + def for_volta(name): + return "884" in name + + def is_cpp(name): + return name.endswith(".cpp") + + def get_src_archs_str_given_requested_cuda_archs(archs, source_file): + intersected_archs = archs & set(self.compute_capabilities_baseline) + if intersected_archs == set(): + raise RuntimeError( + """ + Empty archs set for file {} after taking + the intersection of {} (global requested archs) and + {} (per file requested archs) + """.format(source_file, set(self.compute_capabilities_baseline), archs)) + else: + return " ".join(map(str, intersected_archs)) + + for min_cc in sorted(source_files.keys()): + for source_file in source_files[min_cc]: + if is_cpp(source_file): + continue # skip because source is cpp + elif for_ampere(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file) + elif for_turing(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file) + elif for_volta(source_file): + archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file) + else: + raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file)) + + manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str)) + + # + def emit(self, target = GeneratorTarget.Library): + + operation_emitters = { + GeneratorTarget.Library: EmitOperationKindLibrary + } + + # Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d) + kind_emitters = { + GeneratorTarget.Library: EmitOperationKindAll + } + + interface_emitters = { + GeneratorTarget.Library: EmitInterfaceLibrary + } + + generated_path = os.path.join(self.curr_build_dir, 'generated') + + # create generated/ + if os.path.exists(generated_path): + shutil.rmtree(generated_path) + + os.mkdir(generated_path) + + with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter: + top_level_path = iface_emitter.top_level_path + for operation_kind in self.operations.keys(): + iface_emitter.emit(OperationKindNames[operation_kind]) + + source_files = {} + for kind in self.operations.keys(): + source_files[kind] = {} + for min_cc in self.operations[kind].keys(): + source_files[kind][min_cc] = {} + + for operation_kind, ops in self.operations.items(): + for min_cc, configurations in sorted(ops.items()): + with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter: + for configuration_name, operations in configurations.items(): + _LOGGER.info(f"Emitting {configuration_name} with {len(operations)} operation{'' if len(operations) == 1 else 's'}.") + operation_kind_emitter.emit(configuration_name, operations) + + for subclass, files in operation_kind_emitter.source_files.items(): + if subclass not in source_files[operation_kind][min_cc]: + source_files[operation_kind][min_cc][subclass] = [] + source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass]) + + # Emit top level all_{gemm, conv2d, ...}_operations.cu files + with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: + operation_kind_emitter.emit(ops) + + # write the manifest.cmake file containing paths from all targets + manifest_path = os.path.join(generated_path, "manifest.cmake") + + self.emit_manifest_cmake(manifest_path, top_level_path, source_files) + +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..29ef056f26f914a9c3c33e13900c33642ad2f1b7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py @@ -0,0 +1,438 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Rank2K kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Rank K update operation +# +################################################################################################### + +# +class Rank2KOperation: + # + def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.Rank2K + self.arch = arch + self.tile_description = tile_description + self.rank_k_kind = rank_k_kind + # tensor A and B have same data type and layout + self.A = A + self.B = A + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.C.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.A.alignment, self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % self.A.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitRank2KUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.rank_k_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Rank2K< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.rank_k_complex_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Rank2K< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${transform_a}, + ${transform_b}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'fill_mode': FillModeTag[operation.C.fill_mode], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template + + return SubstituteTemplate(rank_k_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitRank2KConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + RankKKind.Universal: EmitRank2KUniversalInstance, + } + + self.rank_k_kind_wrappers = { + RankKKind.Universal: 'Rank2KOperation', + } + + self.instance_template = { + RankKKind.Universal: """ +${compile_guard_start} + manifest.append(new ${rank_k_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by rank_2k_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "rank_2k_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.rank_k_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..9841952332a170d6f401dbe34a0093540c166fb8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py @@ -0,0 +1,427 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting RankK kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Rank K update operation +# +################################################################################################### + +# +class RankKOperation: + # + def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.RankK + self.arch = arch + self.tile_description = tile_description + self.rank_k_kind = rank_k_kind + self.A = A + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return False + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.C.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.A.alignment, self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % self.A.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitRankKUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.rank_k_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::RankK< + ${element_a}, ${layout_a}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.rank_k_complex_template = """ +// Rank K operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::RankK< + ${element_a}, ${layout_a}, + ${element_c}, ${layout_c}, ${fill_mode}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${split_k_serial}, + ${math_operation}, + ${transform_a}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'fill_mode': FillModeTag[operation.C.fill_mode], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template + + return SubstituteTemplate(rank_k_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitRankKConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + RankKKind.Universal: EmitRankKUniversalInstance, + } + + self.rank_k_kind_wrappers = { + RankKKind.Universal: 'RankKOperation', + } + + self.instance_template = { + RankKKind.Universal: """ +${compile_guard_start} + manifest.append(new ${rank_k_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by rank_k_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "rank_k_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.rank_k_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..32e4376513679f06dc085ead068e258b3d8b5e72 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py @@ -0,0 +1,342 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid tcgen05 shapes and cluster sizes for SM100, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (tcgen05 shape, cluster +size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. +Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. +Level 0 is always emitted. +""" + +try: + from .library import DynamicClusterShape +except: + from library import DynamicClusterShape + +SM100_CLUSTER_SHAPES_1SM = { + tuple(DynamicClusterShape) : 0, + # size 1 cluster + (1, 1, 1): 1, + # size 2 cluster + (1, 2, 1): 2, + (2, 1, 1): 5, + # size 4 clusters + (2, 2, 1): 6, + (1, 4, 1): 3, + (4, 1, 1): 6, + # size 8 clusters + (2, 4, 1): 7, + (4, 2, 1): 7, + (1, 8, 1): 8, + (8, 1, 1): 8, + # size 16 cluster + (4, 4, 1): 4, +} + +SM100_CLUSTER_SHAPES_2SM = { + tuple(DynamicClusterShape) : 0, + # size 2 cluster + (2, 1, 1): 1, + # size 4 clusters + (2, 2, 1): 2, + (4, 1, 1): 2, + # size 8 clusters + (2, 4, 1): 3, + (4, 2, 1): 3, + (8, 1, 1): 6, + # size 16 cluster + (4, 4, 1): 4, +} + +# MMA shapes + +# 16b Dense + +SM100_MMA_SHAPES_16b_DENSE_1SM = { + (64, 8, 16): 5, + (64, 16, 16): 2, + (64, 24, 16): 5, + (64, 32, 16): 2, + (64, 40, 16): 5, + (64, 48, 16): 5, + (64, 56, 16): 5, + (64, 64, 16): 2, + (64, 72, 16): 5, + (64, 80, 16): 5, + (64, 88, 16): 5, + (64, 96, 16): 5, + (64, 104, 16): 5, + (64, 112, 16): 5, + (64, 120, 16): 5, + (64, 128, 16): 0, + (64, 136, 16): 5, + (64, 144, 16): 5, + (64, 152, 16): 5, + (64, 160, 16): 5, + (64, 168, 16): 5, + (64, 176, 16): 5, + (64, 184, 16): 5, + (64, 192, 16): 3, + (64, 200, 16): 5, + (64, 208, 16): 5, + (64, 216, 16): 5, + (64, 224, 16): 5, + (64, 232, 16): 5, + (64, 240, 16): 5, + (64, 248, 16): 5, + (64, 256, 16): 3, + + (128, 16, 16): 2, + (128, 32, 16): 2, + (128, 48, 16): 5, + (128, 64, 16): 2, + (128, 80, 16): 5, + (128, 96, 16): 5, + (128, 112, 16): 5, + (128, 128, 16): 0, + (128, 144, 16): 5, + (128, 160, 16): 5, + (128, 176, 16): 5, + (128, 192, 16): 3, + (128, 208, 16): 5, + (128, 224, 16): 5, + (128, 240, 16): 5, + (128, 256, 16): 0, + +} + + +SM100_MMA_SHAPES_16b_DENSE_2SM = { + (128, 32, 16): 2, + (128, 64, 16): 2, + (128, 96, 16): 5, + (128, 128, 16): 0, + (128, 160, 16): 5, + (128, 192, 16): 5, + (128, 224, 16): 5, + (128, 256, 16): 0, + + (256, 32, 16): 2, + (256, 64, 16): 2, + (256, 96, 16): 5, + (256, 128, 16): 0, + (256, 160, 16): 5, + (256, 192, 16): 3, + (256, 224, 16): 5, + (256, 256, 16): 0, +} + +# TF32 Dense + +SM100_MMA_SHAPES_TF32_DENSE_1SM = { + (64, 8, 8): 5, + (64, 16, 8): 2, + (64, 24, 8): 5, + (64, 32, 8): 2, + (64, 40, 8): 5, + (64, 48, 8): 5, + (64, 56, 8): 5, + (64, 64, 8): 1, + (64, 72, 8): 5, + (64, 80, 8): 5, + (64, 88, 8): 5, + (64, 96, 8): 5, + (64, 104, 8): 5, + (64, 112, 8): 5, + (64, 120, 8): 5, + (64, 128, 8): 0, + (64, 136, 8): 5, + (64, 144, 8): 5, + (64, 152, 8): 5, + (64, 160, 8): 5, + (64, 168, 8): 5, + (64, 176, 8): 5, + (64, 184, 8): 5, + (64, 192, 8): 3, + (64, 200, 8): 5, + (64, 208, 8): 5, + (64, 216, 8): 5, + (64, 224, 8): 5, + (64, 232, 8): 5, + (64, 240, 8): 5, + (64, 248, 8): 5, + (64, 256, 8): 3, + + (128, 16, 8): 2, + (128, 32, 8): 2, + (128, 48, 8): 5, + (128, 64, 8): 2, + (128, 80, 8): 5, + (128, 96, 8): 5, + (128, 112, 8): 5, + (128, 128, 8): 0, + (128, 144, 8): 5, + (128, 160, 8): 5, + (128, 176, 8): 5, + (128, 192, 8): 3, + (128, 208, 8): 5, + (128, 224, 8): 5, + (128, 240, 8): 5, + (128, 256, 8): 0, + +} + +SM100_MMA_SHAPES_TF32_DENSE_2SM = { + (128, 32, 8): 2, + (128, 64, 8): 1, + (128, 96, 8): 5, + (128, 128, 8): 0, + (128, 160, 8): 5, + (128, 192, 8): 5, + (128, 224, 8): 5, + (128, 256, 8): 0, + + (256, 32, 8): 2, + (256, 64, 8): 1, + (256, 96, 8): 5, + (256, 128, 8): 0, + (256, 160, 8): 5, + (256, 192, 8): 5, + (256, 224, 8): 5, + (256, 256, 8): 0, +} + +# F8F6F4 +SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = { + (64, 8, 32): 4, + (64, 16, 32): 4, + (64, 24, 32): 5, + (64, 32, 32): 3, + (64, 40, 32): 5, + (64, 48, 32): 5, + (64, 56, 32): 5, + (64, 64, 32): 2, + (64, 72, 32): 5, + (64, 80, 32): 5, + (64, 88, 32): 5, + (64, 96, 32): 5, + (64, 104, 32): 5, + (64, 112, 32): 5, + (64, 120, 32): 5, + (64, 128, 32): 0, + (64, 136, 32): 5, + (64, 144, 32): 5, + (64, 152, 32): 5, + (64, 160, 32): 5, + (64, 168, 32): 5, + (64, 176, 32): 5, + (64, 184, 32): 5, + (64, 192, 32): 5, + (64, 200, 32): 5, + (64, 208, 32): 5, + (64, 216, 32): 5, + (64, 224, 32): 5, + (64, 232, 32): 5, + (64, 240, 32): 5, + (64, 248, 32): 5, + (64, 256, 32): 0, + + (128, 16, 32): 4, + (128, 32, 32): 3, + (128, 48, 32): 5, + (128, 64, 32): 2, + (128, 80, 32): 5, + (128, 96, 32): 5, + (128, 112, 32): 5, + (128, 128, 32): 0, + (128, 144, 32): 5, + (128, 160, 32): 5, + (128, 176, 32): 5, + (128, 192, 32): 5, + (128, 208, 32): 5, + (128, 224, 32): 5, + (128, 240, 32): 5, + (128, 256, 32): 0, + +} + +SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = { + (128, 32, 32): 3, + (128, 64, 32): 2, + (128, 96, 32): 5, + (128, 128, 32): 1, + (128, 160, 32): 5, + (128, 192, 32): 5, + (128, 224, 32): 5, + (128, 256, 32): 1, + + (256, 32, 32): 2, + (256, 64, 32): 2, + (256, 96, 32): 5, + (256, 128, 32): 0, + (256, 160, 32): 5, + (256, 192, 32): 5, + (256, 224, 32): 5, + (256, 256, 32): 0, +} + +# MXF8F6F4 +SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = { + (128, 64, 32): 1, + (128, 128, 32): 0, + (128, 192, 32): 1, + (128, 256, 32): 0, +} + + +SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = { + (256, 64, 32): 1, + (256, 128, 32): 0, + (256, 192, 32): 1, + (256, 256, 32): 0, + + +} + +# MXF4NVF4 +SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = { + (128, 64, 64): 1, + (128, 128, 64): 0, + (128, 192, 64): 1, + (128, 256, 64): 0, +} + +SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = { + # Multiples of 16 for N + (256, 64, 64): 1, + (256, 128, 64): 0, + (256, 192, 64): 1, + (256, 256, 64): 0, + +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf24fe7f528020be4dcfc6ac41cfe949dd63be5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM100 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List, Union, Callable + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. + + +def get_tcgen05_level_from_global_level(global_level: int): + return global_level % 10 + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm100_shapes import * +except: + from sm100_shapes import * + +########### + +def generate_tf32_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_16b_math_instructions_sm100(level: int): + """ + Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + + +def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_1sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + for shape in shapes_2sm: + if enable_runtime_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.f8, DataType.f8, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if enable_compile_time_dtype: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + if pruning_level >= 2: + math_instructions_2sm.append( + MathInstruction( + shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + + tcgen05_level = get_tcgen05_level_from_global_level(level) + pruning_level = get_pruning_level_from_global_level(level) + + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + + if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)): + continue + + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e4m3, + DataType.e5m2, + DataType.e3m2, + DataType.e2m3, + DataType.e2m1 ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + + return math_instructions_1sm, math_instructions_2sm + +def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True): + """ + Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level. + + Args: + level: The global level to generate math instructions for. + enable_runtime_dtype: Whether to generate runtime dtype math instructions. + enable_compile_time_dtype: Whether to generate compile time dtype math instructions. + + Returns: + A tuple of two lists of MathInstruction objects. + The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM. + """ + tcgen05_level = get_tcgen05_level_from_global_level(level) + math_instructions_1sm = [] + math_instructions_2sm = [] + + shapes_1sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level + ] + shapes_2sm = [ + shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level + ] + + for shape in shapes_1sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_1sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + for shape in shapes_2sm: + if enable_runtime_dtype: + + runtime_types = [ DataType.f4 ] + + for a_type, b_type in product(runtime_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + if enable_compile_time_dtype: + compile_time_types = [ DataType.e2m1, + ] + + for a_type, b_type in product(compile_time_types, repeat=2): + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue8m0) + ) + math_instructions_2sm.append( + MathInstruction( + shape, + a_type, b_type, DataType.f32, + OpcodeClass.BlockScaledTensorOp, + MathOperation.multiply_add, + DataType.ue4m3) + ) + + + return math_instructions_1sm, math_instructions_2sm + + +def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None): + """ + Generate all cluster shapes for SM100 at or above the given level. + + Args: + level: The global level to generate cluster shapes for. + + Returns: + A tuple of two lists of cluster shapes. + The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM. + """ + cluster_level = get_cluster_level_from_global_level(level) + + assert cluster_level >= 4 + + if change_priority_func is not None: + SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM) + SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM) + change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY) + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm + + else: + + shapes_1sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level + ] + shapes_2sm = [ + list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level + ] + + return shapes_1sm, shapes_2sm diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..e14761aae6494f877e6dc6521b30baea0db7509c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py @@ -0,0 +1,212 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels. +These shape and level pairs are defined as dicts, where keys are shapes and values are their +associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster +size) is smaller than a shape's associated level, it will be excluded, and otherwise, included. +Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently. +Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted +when the `--kernel` argument is non-empty. +""" + +# NOTE: more combinations are possible here. +# Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes. +# The rest are only used in the exhaustive mode (when the corresponding level digit is 9). +# MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes. +SM90_MMA_MULTIPLIERS = { + (2, 1, 4): 0, + (1, 1, 4): 1, + (4, 1, 4): 2, + (2, 2, 4): 3, + (2, 1, 8): 4, + (4, 1, 8): 4, + (1, 1, 8): 4, + (2, 2, 8): 4, + (2, 1, 16): 5, + (4, 1, 16): 5, + (1, 1, 16): 5, + (2, 2, 16): 5, +} + +# Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case +# Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case +# Level 2: clusters with 1 or 2 CTAs +# Level 3: clusters with 1, 2, or 4 CTAs +# Level 4: clusters with 1, 2, 4, or 8 CTAs +# Level 5: clusters with 1, 2, 4, 8, or 16 CTAs +SM90_CLUSTER_SIZES = { + (1, 2, 1): 0, + (2, 1, 1): 1, + (1, 1, 1): 2, + (2, 2, 1): 3, + (1, 4, 1): 3, + (4, 1, 1): 3, + (2, 4, 1): 4, + (4, 2, 1): 4, + (1, 8, 1): 4, + (8, 1, 1): 4, + (4, 4, 1): 5, +} + + +# WGMMA shapes +# Level 0: "default" shape only, +# Level 1: additional shapes for the unpruned case (tf32 only) +# Level 2: shapes that are all powers of 2 +# Level 3: all other shapes +SM90_WGMMA_SHAPES_FP16_BF16_DENSE = { + (64, 8, 16): 2, + (64, 16, 16): 2, + (64, 24, 16): 3, + (64, 32, 16): 2, + (64, 40, 16): 3, + (64, 48, 16): 3, + (64, 56, 16): 3, + (64, 64, 16): 2, + (64, 72, 16): 3, + (64, 80, 16): 3, + (64, 88, 16): 3, + (64, 96, 16): 3, + (64, 104, 16): 3, + (64, 112, 16): 3, + (64, 120, 16): 3, + (64, 128, 16): 0, + (64, 136, 16): 3, + (64, 144, 16): 3, + (64, 152, 16): 3, + (64, 160, 16): 3, + (64, 168, 16): 3, + (64, 176, 16): 3, + (64, 184, 16): 3, + (64, 192, 16): 3, + (64, 200, 16): 3, + (64, 208, 16): 3, + (64, 216, 16): 3, + (64, 224, 16): 3, + (64, 232, 16): 3, + (64, 240, 16): 3, + (64, 248, 16): 3, + (64, 256, 16): 1, +} + +SM90_WGMMA_SHAPES_TF32_DENSE = { + (64, 8, 8): 2, + (64, 16, 8): 2, + (64, 24, 8): 3, + (64, 32, 8): 2, + (64, 40, 8): 3, + (64, 48, 8): 3, + (64, 56, 8): 3, + (64, 64, 8): 2, + (64, 72, 8): 3, + (64, 80, 8): 3, + (64, 88, 8): 3, + (64, 96, 8): 3, + (64, 104, 8): 3, + (64, 112, 8): 3, + (64, 120, 8): 3, + (64, 128, 8): 0, + (64, 136, 8): 3, + (64, 144, 8): 3, + (64, 152, 8): 3, + (64, 160, 8): 3, + (64, 168, 8): 3, + (64, 176, 8): 3, + (64, 184, 8): 3, + (64, 192, 8): 3, + (64, 200, 8): 3, + (64, 208, 8): 3, + (64, 216, 8): 3, + (64, 224, 8): 3, + (64, 232, 8): 3, + (64, 240, 8): 3, + (64, 248, 8): 3, + (64, 256, 8): 1, +} + +SM90_WGMMA_SHAPES_FP8_DENSE = { + (64, 8, 32): 2, + (64, 16, 32): 2, + (64, 24, 32): 3, + (64, 32, 32): 2, + (64, 40, 32): 3, + (64, 48, 32): 3, + (64, 56, 32): 3, + (64, 64, 32): 2, + (64, 72, 32): 3, + (64, 80, 32): 3, + (64, 88, 32): 3, + (64, 96, 32): 3, + (64, 104, 32): 3, + (64, 112, 32): 3, + (64, 120, 32): 3, + (64, 128, 32): 0, + (64, 136, 32): 3, + (64, 144, 32): 3, + (64, 152, 32): 3, + (64, 160, 32): 3, + (64, 168, 32): 3, + (64, 176, 32): 3, + (64, 184, 32): 3, + (64, 192, 32): 3, + (64, 200, 32): 3, + (64, 208, 32): 3, + (64, 216, 32): 3, + (64, 224, 32): 3, + (64, 232, 32): 3, + (64, 240, 32): 3, + (64, 248, 32): 3, + (64, 256, 32): 1, +} + +SM90_WGMMA_SHAPES_INT8_DENSE = { + (64, 8, 32): 2, + (64, 16, 32): 2, + (64, 24, 32): 3, + (64, 32, 32): 2, + (64, 48, 32): 3, + (64, 64, 32): 2, + (64, 80, 32): 3, + (64, 96, 32): 3, + (64, 112, 32): 3, + (64, 128, 32): 0, + (64, 144, 32): 3, + (64, 160, 32): 3, + (64, 176, 32): 3, + (64, 192, 32): 3, + (64, 208, 32): 3, + (64, 224, 32): 3, + (64, 240, 32): 3, + (64, 256, 32): 1, +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5fdf14abb85835f71ecfd704a2738f5792af50 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py @@ -0,0 +1,753 @@ +################################################################################################# +# +# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for enumerating CUTLASS library SM90 kernels +""" + +import argparse +import enum +from itertools import product +import math +import logging +import os.path +import shutil +import sys +import copy +from typing import Any, Optional, Sequence, Tuple, List + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + +# NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py +def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): + + # by default, use the latest CUDA Toolkit version + cuda_version = [11, 0, 132] + + # Update cuda_version based on parsed string + if semantic_ver_string != '': + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): + if i < len(cuda_version): + cuda_version[i] = x + else: + cuda_version.append(x) + return cuda_version >= [major, minor, patch] + +#### Step 0: define levels + +# One integer level controls multiple "generators" and how many +# combinations they generate. That is the "global" level. +# "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and +# anything that is eventually involved in the Cartesian product +# which yields our kernel configurations. +# For simplicity, each generator defines their own levels, +# starting from 0. As a rule we assume 10 or fewer levels, making +# their level a digit. +# The "global" level simply stacks these digits and represents them +# as a single integer. +# +# For example, level 500 indicates cluster sizes are at level 5, MMA +# multipliers are at level 0, and WGMMA shapes are at level 0 as well. +# +# Here we define the global level to generator level mappings. + + +def get_wgmma_level_from_global_level(global_level: int): + return global_level % 10 + + +def get_mma_level_from_global_level(global_level: int): + return (global_level // 10) % 10 + + +def get_cluster_level_from_global_level(global_level: int): + return (global_level // 100) % 10 + + +def get_pruning_level_from_global_level(global_level: int): + return (global_level // 1000) % 10 + + +#### Step 1: generate MMA instruction shapes based on levels + +try: + from .sm90_shapes import ( + SM90_MMA_MULTIPLIERS, + SM90_CLUSTER_SIZES, + SM90_WGMMA_SHAPES_TF32_DENSE, + SM90_WGMMA_SHAPES_FP16_BF16_DENSE, + SM90_WGMMA_SHAPES_FP8_DENSE, + SM90_WGMMA_SHAPES_INT8_DENSE, + ) +except: + from sm90_shapes import ( + SM90_MMA_MULTIPLIERS, + SM90_CLUSTER_SIZES, + SM90_WGMMA_SHAPES_TF32_DENSE, + SM90_WGMMA_SHAPES_FP16_BF16_DENSE, + SM90_WGMMA_SHAPES_FP8_DENSE, + SM90_WGMMA_SHAPES_INT8_DENSE, + ) + + +def generate_tf32_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_fp16_bf16_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_fp8_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_int8_math_instruction_shapes_sm90(level: int): + assert isinstance(level, int) and level >= 0 + filtered_list_of_wgmma_shapes = [ + wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level + ] + return filtered_list_of_wgmma_shapes + +def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType): + # DataTypeSize are in the unit of bits + a_bytes = DataTypeSize[a_type] // 8 + b_bytes = DataTypeSize[b_type] // 8 + if a_bytes == 4 or b_bytes == 4: + return generate_tf32_math_instruction_shapes_sm90(wgmma_level) + elif a_bytes == 2 or b_bytes == 2: + return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level) + else: + return generate_fp8_math_instruction_shapes_sm90(wgmma_level) + +########### + +def generate_tf32_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level): + math_instructions.append( + MathInstruction( + math_instruction_shape, + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + ) + return math_instructions + +def generate_fp16_bf16_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def generate_fp8_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_fp8_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for a_type, b_type, acc_type in types_of_a_b_acc: + math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type) + for math_instruction_shape in math_instruction_shapes: + math_instructions += [ + MathInstruction( + math_instruction_shape, + a_type, b_type, acc_type, + OpcodeClass.TensorOp, + MathOperation.multiply_add + ), + ] + return math_instructions + +def generate_int8_math_instructions_sm90(level: int): + wgmma_level = get_wgmma_level_from_global_level(level) + math_instructions = [] + for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level): + math_instructions += [ + MathInstruction( + math_instruction_shape, + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + math_instruction_shape, + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + return math_instructions + +def make_sparse_math_instructions(math_instructions): + sparse_instructions = [] + for inst in math_instructions: + if inst.opcode_class == OpcodeClass.TensorOp: + sparse_instructions.append(MathInstruction( + (inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2), + inst.element_a, inst.element_b, inst.element_accumulator, + OpcodeClass.SparseTensorOp, + inst.math_operation),) + return sparse_instructions + + +#### Step 2: generate tile descriptions from math instruction shapes + +def is_tile_desc_valid(tile_description): + if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90: + return False + + element_a, element_b, element_accum = ( + tile_description.math_instruction.element_a, + tile_description.math_instruction.element_b, + tile_description.math_instruction.element_accumulator + ) + + cluster_size, cta_shape = ( + tile_description.cluster_shape, + tile_description.threadblock_shape, + ) + grid_size = ( + cta_shape[0] * cluster_size[0] + + cta_shape[1] * cluster_size[1] + + cta_shape[2] * cluster_size[2] + ) + num_ctas_in_cluster = cluster_size[0] * cluster_size[1] * cluster_size[2] + cluster_shape = ( + cluster_size[0] * cta_shape[0], + cluster_size[1] * cta_shape[1], + cluster_size[2] * cta_shape[2] + ) + + FP32_TYPES = [DataType.f32, DataType.tf32] + FP16_TYPES = [DataType.f16, DataType.bf16] + is_fp32 = element_a in FP32_TYPES and element_b in FP32_TYPES + is_fp16 = element_a in FP16_TYPES and element_b in FP16_TYPES + + # Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is + # allowed for non portable clusters. + if num_ctas_in_cluster > 16 or num_ctas_in_cluster < 1: + return False + + if grid_size < 1: + return False + + # SM90 WGMMA shapes are always 64 across M, therefore + # CTA shape across M must always be a multiple of 64. + if cta_shape[0] < 64 or cta_shape[0] % 64 != 0: + return False + + # The minimum WGMMA shape across N is 8, and increments + # vary across different dtypes, but they're never smaller + # than 8. The minimum CTA shape allowed across N though is 16. + if cta_shape[1] < 16 or cta_shape[1] % 8 != 0: + return False + + # SM90 WGMMA shapes across K are always 8 for 32 bit dense + # operations, 16 for 16 bit, and 32 for 8 bit. In any case, + # the CTA shape across K should be a multiple of 8 and at least + # twice the WGMMA shape across K. + if cta_shape[2] < 16 or cta_shape[2] % 8 != 0: + return False + + # Minimum of 2 stages (very rough heuristic that may filter out valid kernel configs) + if (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 256: + return False + + if is_fp32 and (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 128: + return False + + if is_fp32 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 64: + return False + + if is_fp16 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 128: + return False + + # CTA shape upper bound: <256, 256, 256> + if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256: + return False + + return True + +def get_mma_multipliers(level: int): + assert isinstance(level, int) and level >= 0 + mma_level = get_mma_level_from_global_level(level) + return [ + mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level + ] + +def get_cluster_sizes(level: int, is_aligned: bool): + if not is_aligned: + return [(1, 1, 1)] + assert isinstance(level, int) and level >= 0 + cluster_level = get_cluster_level_from_global_level(level) + return [ + cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level + ] + +def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int): + tile_descriptions = set() + mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned) + for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes): + + # generator can stamp out duplicate kernels, because it doesn't explicitly set instruction + # shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using + # the auto kernel schedule. + + math_inst_stub = copy.deepcopy(math_inst) + math_inst_stub.instruction_shape = [0, 0, 0] + + tile_desc = TileDescription( + threadblock_shape=[ + math_inst.instruction_shape[0] * mma_mul[0], + math_inst.instruction_shape[1] * mma_mul[1], + math_inst.instruction_shape[2] * mma_mul[2] + ], + stages=0, + warp_count=[4, 1, 1], + math_instruction=math_inst_stub, + min_compute=90, + max_compute=90, + cluster_shape=cluster_size) + # For sparse kernels K-tile is twice as large (due to 2x MMA-K size) + # Reduce it to same size as dense to afford more smem stages + if math_inst.opcode_class == OpcodeClass.SparseTensorOp: + tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2 + if is_tile_desc_valid(tile_desc): + tile_descriptions.add(tile_desc) + + return tile_descriptions + +#### Step 3: map tile description to valid schedules + +def is_tile_desc_compatible_with_cooperative(tile_description): + # Cooperative kernels require a minimum CTA-M of 128 + return tile_description.threadblock_shape[0] % 128 == 0 + + +def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types): + dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = ( + data_types["a_type"], + data_types["b_type"], + data_types["c_type"], + data_types["d_type"], + data_types["acc_type"], + data_types["epi_type"] + ) + mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1] + bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d] + + shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn + shmem_bits_total = shmem_bits_c + shmem_bits_d + # Magic number: 2^20 + # Existing logic suggested that tile shape 256x128 (or 128x256) + # would run out of shmem if D is FP32, and source is needed. + # That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit. + # Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB. + # Since epilogue can't possibly use ALL of the shmem available + # we can just settle on 2^20 bits (~ 131 KB) being the upper bound + # we would allow for epilogue. + # This can be different for non-persistent kernels where epilogue and + # mainloop shmem is shared. + if shmem_bits_total > 2 ** 20: + return False + + return True + + +def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout, + instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x): + # Level 0: prune according to existing generator.py behavior + # Level >= 1: no pruning + level = get_pruning_level_from_global_level(instantiation_level) + schedules = [] + stream_k_schedules = [] + + if not is_tile_desc_valid(tile_description): + return schedules, stream_k_schedules + + FP16_TYPES = [DataType.f16, DataType.bf16] + is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES + + FP8_TYPES = [DataType.e4m3, DataType.e5m2] + is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES + can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc + + FP32_TYPES = [DataType.f32, DataType.tf32] + is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES + requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor + + can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description) + can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types) + + default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed + auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed + + cta_m, cta_n, cta_k = ( + tile_description.threadblock_shape[0], + tile_description.threadblock_shape[1], + tile_description.threadblock_shape[2] + ) + c_type = data_types["c_type"] + d_type = data_types["d_type"] + is_void_c = c_type == DataType.void + + # Filter out invalid kernels + is_nt = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.RowMajor + is_tn = layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.ColumnMajor + is_nn = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.ColumnMajor + + # static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + # "Copy size must evenly divide SMEM tile."); + if is_fp32 and is_nt and (cta_n % cta_k != 0): + return [], [] + + # static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits::value))) == 128, + # "SmemLayoutB K must be 128bytes to be transposed.") + if is_fp32 and is_nt and cta_k != 32: + return [], [] + + # Static assert failure when instantiating SmemLayoutB + if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0): + return [], [] + + grouped = is_grouped(gemm_kind) + if grouped: + # the following cases are unsupported by grouped GEMM + if not is_aligned: + return [], [] + if requires_transposed_epilogue: + return [], [] + + # Early pruning + if level < 1: + # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64 + if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64: + return [], [] + + # FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules + is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128 + if is_large_fp8_tile: + # Only void-C, and only FP8 outputs allowed + if not is_void_c or d_type not in FP8_TYPES: + return [], [] + if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue: + schedules = [] + if is_blockwise(gemm_kind): + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + else: + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + schedules.append( + [ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + return schedules, [] + return [], [] + + if is_fp8 and not is_large_fp8_tile: + valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void] + # Prune all configs with fp8 source, and all configs with non-fp8 output + # that have different dtypes for source and output. + if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type): + return [], [] + + # FP32/TF32 kernels don't stamp out void-C + if is_fp32 and is_void_c: + return [], [] + + # Void-c only makes a difference for TMA epilogues + if is_void_c and not can_do_tma_epilogue: + return [], [] + + # For mixed input data types + a_type_size = DataTypeSize[data_types["a_type"]] + b_type_size = DataTypeSize[data_types["b_type"]] + if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules = [] + stream_k_schedules = [] + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + + if not is_blockwise(gemm_kind): + schedules.append([ + KernelScheduleType.TmaWarpSpecialized, + epilogue_schedule + ]) + schedules.append([ + KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule + ]) + if cta_m >= 128: + if a_type_size > b_type_size: + epilogue_schedule = EpilogueScheduleType.EpilogueTransposed + else: + epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative + if is_blockwise(gemm_kind): + schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + epilogue_schedule + ]) + else: + schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule + ]) + return schedules, stream_k_schedules + + if not is_aligned and not is_blockwise(gemm_kind): + schedules = [[KernelScheduleType.CpAsyncWarpSpecialized, + default_epilogue]] + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative: + schedules.append([ + KernelScheduleType.CpAsyncWarpSpecializedCooperative, + default_epilogue + ]) + stream_k_schedules.append([ + KernelScheduleType.CpAsyncWarpSpecializedCooperative, + default_epilogue + ]) + + return schedules, stream_k_schedules + + schedules = [] + # Pruning: emit Void-C and Grouped kernels with persistent kernels only + if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind): + # Pruning: don't stamp out fp8 kernels with auto schedule + if not is_fp8: + schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue]) + schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue]) + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 0): + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + # Inconsistency: fp8 pingpong only gets stamped out with fast accum + if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped) + ]) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue + if not is_fp8 or level >= 1: + if not is_blockwise(gemm_kind): + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + else: + schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)]) + + if can_do_fp8_fast_accum: + if not grouped: + schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue]) + schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)]) + + if can_do_cooperative: + if is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + default_epilogue + ]) + else: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + default_epilogue + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(default_epilogue, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, + default_epilogue + ]) + + # persistent kernels with TMA epilogues + if can_do_tma_epilogue: + assert not requires_transposed_epilogue + if can_do_cooperative: + if is_blockwise(gemm_kind): + schedules.append([ + to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + else: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperative, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + if can_do_fp8_fast_accum: + schedules.append([ + to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped), + to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped) + ]) + stream_k_schedules.append([ + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, + EpilogueScheduleType.TmaWarpSpecializedCooperative + ]) + # Grouped GEMM do not support Stream-K scheduler + if grouped: + return schedules, [] + return schedules, stream_k_schedules + + +#### Misc: helpers + +def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None): + element_a, element_b = math_instruction.element_a, math_instruction.element_b + element_accumulator = math_instruction.element_accumulator + element_c = element_source or element_accumulator + element_d = element_dest or element_accumulator + element_epilogue = element_epilogue or element_accumulator + data_types = { + "a_type" : element_a, + "b_type" : element_b, + "c_type" : element_c, + "d_type" : element_d, + "acc_type" : element_accumulator, + "epi_type" : element_epilogue + } + return data_types + +def fix_alignments(data_types, layout, alignment_bits = 128): + operand_keys = ["a_type", "b_type", "c_type"] + operands_to_fix = ["c_type"] + new_layout = [] + assert len(layout) == len(operand_keys) + for i, k in enumerate(operand_keys): + assert k in data_types and data_types[k] in DataTypeSize + dtype = data_types[k] + dtype_size_bits = DataTypeSize[dtype] + + layout_type = layout[i][0] + layout_alignment = layout[i][1] + + # Don't modify alignment if dtype's been changed to void + if k in operands_to_fix and dtype_size_bits >= 1: + layout_alignment = alignment_bits // dtype_size_bits + + new_layout.append([layout_type, layout_alignment]) + + return new_layout diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..8661ff798b2e3e0987fdf7e050b6ad2e0f8f3678 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py @@ -0,0 +1,440 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Symm kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a Symm update operation +# +################################################################################################### + +# +class SymmOperation: + # + def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \ + blas_mode = BlasMode.symmetric): + + self.blas_mode = blas_mode + self.operation_kind = OperationKind.Symm + self.arch = arch + self.tile_description = tile_description + self.symm_kind = symm_kind + # tensor A and B have same data type and layout + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def is_planar_complex(self): + return False + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm' + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)] + ) + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + # + def side_mode_name(self): + return "%s" % (ShortSideModeNames[self.A.side_mode]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.A.fill_mode]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = self.C.alignment + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'side_mode': self.side_mode_name(), + 'fill_mode': self.fill_mode_name(), + 'alignment': "%d" % alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitSymmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.symm_template = """ +// Symm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Symm< + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.symm_complex_template = """ +// Symm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Symm< + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${blas_mode} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + + warp_count = operation.tile_description.warp_count + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'side_mode': SideModeTag[operation.A.side_mode], + 'fill_mode': FillModeTag[operation.A.fill_mode], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'blas_mode': BlasModeTag[operation.blas_mode] + } + + symm_template = self.symm_complex_template if operation.is_complex() else self.symm_template + + return SubstituteTemplate(symm_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitSymmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + SymmKind.Universal: EmitSymmUniversalInstance, + } + + self.symm_kind_wrappers = { + SymmKind.Universal: 'SymmOperation', + } + + self.instance_template = { + SymmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${symm_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by symm_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "symm_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.symm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.symm_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'symm_kind': self.symm_kind_wrappers[operation.symm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..46ba360cb615c955d329b390c0ab93d13ed88c7c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py @@ -0,0 +1,447 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for emitting Trmm kernels +""" + +import enum +import functools +import operator +import os.path +import shutil + +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * + + +################################################################################################### +# +# Data structure modeling a TRMM operation +# +################################################################################################### + +# +class TrmmOperation: + # + def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue, \ + epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8): + + self.operation_kind = OperationKind.Trmm + self.arch = arch + self.tile_description = tile_description + self.trmm_kind = trmm_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.swizzling_functor = swizzling_functor + + # + def is_complex(self): + complex_operators = [ + MathOperation.multiply_add_complex, + MathOperation.multiply_add_complex_gaussian, + MathOperation.multiply_add_complex_fast_f32 + ] + return self.tile_description.math_instruction.math_operation in complex_operators + return False + + # + def is_planar_complex(self): +# return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray) + return False + + # + def is_mixed_input(self): + return self.A.element != self.B.element + + # + def accumulator_type(self): + accum = self.tile_description.math_instruction.element_accumulator + + if self.is_complex(): + return get_complex_from_real(accum) + + return accum + + # + def short_math_name(self): + if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian: + return "g%s" % ShortDataTypeNames[self.accumulator_type()] + return ShortDataTypeNames[self.accumulator_type()] + + + # + def core_name(self): + ''' The basic operation kind is prefixed with a letter indicating the accumulation type. ''' + + inst_shape = '' + inst_operation = '' + intermediate_type = '' + + math_operations_map = { + MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' + } + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ + self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + + math_op = self.tile_description.math_instruction.math_operation + math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' + + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + inst_shape += math_op_string + + if self.tile_description.math_instruction.element_a != self.A.element and \ + self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator: + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + + return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind]) + + # + def extended_name(self): + ''' Append data types if they differ from compute type. ''' + if self.is_complex(): + extended_name = "${core_name}" + else: + if self.C.element != self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${element_c}_${core_name}_${element_a}" + elif self.C.element == self.tile_description.math_instruction.element_accumulator and \ + self.A.element != self.tile_description.math_instruction.element_accumulator: + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = SubstituteTemplate(extended_name, { + 'element_a': DataTypeNames[self.A.element], + 'element_c': DataTypeNames[self.C.element], + 'core_name': self.core_name() + }) + + return extended_name + + # + def layout_name(self): + if self.is_complex() or self.is_planar_complex(): + return "%s%s" % ( + ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)], + ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)] + ) + return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout]) + + # + def side_mode_name(self): + return "%s" % (ShortSideModeNames[self.A.side_mode]) + + # + def fill_mode_name(self): + return "%s" % (ShortFillModeNames[self.A.fill_mode]) + + # + def diag_type_name(self): + return "%s" % (ShortDiagTypeNames[self.A.diag_type]) + + # + def procedural_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + threadblock = self.tile_description.procedural_name() + + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + alignment = max([self.C.alignment]) + + return SubstituteTemplate( + "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}", + { + 'opcode_class': opcode_class_name, + 'extended_name': self.extended_name(), + 'threadblock': threadblock, + 'layout': self.layout_name(), + 'side_mode': self.side_mode_name(), + 'fill_mode': self.fill_mode_name(), + 'diag_type': self.diag_type_name(), + 'alignment': "%d" % self.C.alignment, + } + ) + + # + def configuration_name(self): + ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' + return self.procedural_name() + +################################################################################################### +# +# Emits single instances of a CUTLASS device-wide operator +# +################################################################################################### + +# +class EmitTrmmUniversalInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.trmm_template = """ +// Trmm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Trmm< + ${element_a}, ${layout_a}, + ${side_mode}, ${fill_mode}, ${diag_type}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation} +>; +""" + self.trmm_complex_template = """ +// Trmm operator ${operation_name} +using Operation_${operation_name} = + typename cutlass::gemm::device::Trmm< + ${element_a}, ${layout_a}, + ${side_mode}, ${fill_mode}, ${diag_type}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue}, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${split_k_serial}, + ${math_operation}, + ${transform_a} +>; +""" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'side_mode' : SideModeTag[operation.A.side_mode], + 'fill_mode': FillModeTag[operation.A.fill_mode], + 'diag_type' : DiagTypeTag[operation.A.diag_type], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(1), # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes + 'align_b': str(operation.B.alignment), + 'split_k_serial': 'false', + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'transform_a': ComplexTransformTag[operation.A.complex_transform] + } + + trmm_template = self.trmm_complex_template if operation.is_complex() else self.trmm_template + + return SubstituteTemplate(trmm_template, values) + +################################################################################################### + + +################################################################################################### +# +# Emitters functions for all targets +# +################################################################################################### + +class EmitTrmmConfigurationLibrary: + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/') + + self.instance_emitter = { + TrmmKind.Universal: EmitTrmmUniversalInstance, + } + + self.trmm_kind_wrappers = { + TrmmKind.Universal: 'TrmmOperation', + } + + self.instance_template = { + TrmmKind.Universal: """ +${compile_guard_start} + manifest.append(new ${trmm_kind}< + Operation_${operation_name} + >("${operation_name}")); +${compile_guard_end} +""" + } + + self.header_template = """ +/* + Generated by trmm_operation.py - Do not edit. +*/ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "library_internal.h" +#include "trmm_operation.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.initialize_function_template = """ + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_${configuration_name}(Manifest &manifest) { + +""" + self.epilogue_template = """ + +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + + def emit(self, operation): + emitter = self.instance_emitter[operation.trmm_kind]() + + self.operations.append(operation) + + self.instance_definitions.append(emitter.emit(operation)) + + self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], { + 'configuration_name': self.configuration_name, + 'operation_name': operation.procedural_name(), + 'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind], + 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "", + 'compile_guard_end': "#endif" \ + if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "" + })) + + def __exit__(self, exception_type, exception_value, traceback): + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + # Add wrapper objects within initialize() function + self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, { + 'configuration_name': self.configuration_name + })) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + +################################################################################################### diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..c396d75a5534493f1ebf90043f2a182eb46abb7f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py @@ -0,0 +1,132 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +# Configuration file for the Sphinx documentation builder. +# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('..')) +sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath('../../media/docs')) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = 'CUTLASS Python interface' +copyright = '2023, NVIDIA' +author = 'NVIDIA' +release = '3.1.0' + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'myst_parser', + 'nbsphinx', + 'nbsphinx_link', + 'sphinx_copybutton', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosectionlabel', + 'sphinx.ext.autosummary', + 'sphinx.ext.coverage', + 'sphinx.ext.extlinks', + 'sphinx.ext.ifconfig', + 'sphinx.ext.intersphinx', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx_inline_tabs', + ] + +source_suffix = { + '.rst': 'restructuredtext', + '.md': 'markdown', +} + +autodoc_typehints = 'description' + +pygments_style = "sphinx" +pygments_dark_style = "monokai" + +templates_path = ['_templates'] +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# Ignore errors when converting notebooks +nbsphinx_allow_errors = True + +language = 'en' +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_static_path = ['_static'] + +html_title = "CUTLASS Python" +html_baseurl = 'docs' +html_theme = 'furo' +html_theme_options = { + "light_logo": "cutlass-logo-small.png", + "dark_logo": "cutlass-logo-small.png", + "light_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "dark_css_variables": { + "color-brand-primary": "#76B900", + "color-brand-content": "#76B900", + }, + "footer_icons": [ + { + "name": "GitHub", + "url": "https://github.com/NVIDIA/cutlass", + "html": """ + + + + """, + "class": "", + }, + ], +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..308a5676b06f00089d1cdfe0fb83b442ca2df36e --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py @@ -0,0 +1,36 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from .int_tuple import * +from .layout import * +from .swizzle import * +from .typing import * diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..3d722130c52142e68a3bcd54ac708012aeeeaad3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py @@ -0,0 +1,225 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Functions for manipulating IntTuples +""" + +from functools import reduce +from itertools import chain +from typing import Union +from .typing import Integer + + +def is_int(x): + return isinstance(x, Integer) + + +def is_tuple(x): + return isinstance(x, tuple) + + +def flatten(t): + if is_tuple(t): + if len(t) == 0: + return () + else: + return tuple(i for a in t for i in flatten(a)) + else: + return (t,) + + +def signum(a): + return bool(a > 0) - bool(a < 0) + + +def product(a): + if is_tuple(a): + return reduce(lambda val,elem : val*product(elem), a, 1) + else: + return a + + +def inner_product(a, b): + if is_tuple(a): # tuple tuple + assert len(a) == len(b) + return sum(inner_product(x,y) for x,y in zip(a,b)) + else: # "int" "int" + assert not is_tuple(b) + return a * b + + +def tuple_max(a): + if is_tuple(a): + return max(tuple_max(x) for x in a) + else: + return a + + +def elem_scale(a, b): + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(elem_scale(x,y) for x,y in zip(a,b)) + else: # tuple "int" + assert False # Error + else: + if is_tuple(b): # "int" tuple + return elem_scale(a, product(b)) + else: # "int" "int" + return a * b + + +# Inclusive prefix ceil div with output congruent to input a +def shape_div(a, b): + if is_tuple(a): + if is_tuple(b): # tuple tuple + assert len(a) == len(b) + return tuple(shape_div(x,y) for x,y in zip(a,b)) + else: # tuple "int" + #r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] + r = [] + for v in a: + r.append(shape_div(v,b)) + b = shape_div(b,product(v)) + return tuple(r) + else: + if is_tuple(b): # "int" tuple + return shape_div(a, product(b)) + else: # "int" "int" + assert a % b == 0 or b % a == 0 + return (a + b - 1) // b + +# Exclusive prefix product with output congruent to input a +def prefix_product(a, init=1): + if is_tuple(a): + if is_tuple(init): # tuple tuple + assert len(a) == len(init) + return tuple(prefix_product(x,i) for x,i in zip(a,init)) + else: # tuple "int" + #r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))] + r = [] + for v in a: + r.append(prefix_product(v,init)) + init = init * product(v) + return tuple(r) + else: + if is_tuple(init): # "int" tuple + assert False # Error + else: # "int" "int" + return init + + +def idx2crd(idx, shape, stride=None): + if stride is None: + stride = prefix_product(shape) + + if is_tuple(idx): + if is_tuple(shape): # tuple tuple tuple + assert len(idx) == len(shape) and len(idx) == len(stride) + return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride)) + else: # tuple "int" "int" + assert False # Error + else: + if is_tuple(shape): # "int" tuple tuple + assert len(shape) == len(stride) + return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride)) + else: # "int" "int" "int" + return (idx // stride) % shape + + +def crd2idx(crd, shape, stride=None): + if stride is None: + stride = prefix_product(shape) + + if is_tuple(crd): + if is_tuple(shape): # tuple tuple tuple + assert len(crd) == len(shape) and len(crd) == len(stride) + return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) + else: # tuple "int" "int" + assert False, f"crd={crd}, shape={shape}" # Error + else: + if crd is None: + crd = 0 + + if is_tuple(shape): # "int" tuple tuple + assert len(shape) == len(stride) + result = 0 + for i in range(len(shape)-1): + result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) + crd = crd // product(shape[i]) + return result + crd2idx(crd, shape[-1], stride[-1]) + else: # "int" "int" "int" + return crd * stride + + +# Transform crd into the dst_shape's iteration space +def crd2crd(crd, dst_shape, src_shape=None): + if is_tuple(crd): + if is_tuple(dst_shape): # tuple tuple + assert len(crd) == len(dst_shape) + return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape)) + else: # tuple "int" + # Ambiguous unless we have src_shape + assert src_shape is not None + return crd2idx(crd, src_shape) + else: + if is_tuple(dst_shape): # "int" tuple + return idx2crd(crd, dst_shape) + else: # "int" "int" + assert crd < dst_shape + return crd + + +# Filter trg according to crd: keep only elements of trg that are paired with None +def slice_(crd: Union[None, tuple, int], + trg: Union[tuple, int]): + if is_tuple(crd): + if is_tuple(trg): # tuple tuple + assert len(crd) == len(trg) + # match C++ behavior of `filter_tuple` using `tuple_cat(...)` + return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)]))) + else: + assert False # tuple "int" : Error + elif crd is None: + # match C++ behavior `return cute::tuple{b};` + return (trg,) + else: + return () + + +# Determine if None appears at any of an int_tuples' terminals +def has_none(a: Union[None, tuple, int]): + if is_tuple(a): + return any(has_none(v) for v in a) + else: + return a is None diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..7c220eb16dd089c65fdbe6d6929b357ace0a77c1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py @@ -0,0 +1,367 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Definition of CuTe Layouts and functions to manipulate them +""" + +from itertools import chain +from typing import Union + +from .int_tuple import * + + +class LayoutBase: + pass + + +def is_layout(x): + return isinstance(x, LayoutBase) + + +class Layout(LayoutBase): + def __init__(self, _shape, _stride=None): + self.shape = _shape + if _stride is None: + self.stride = prefix_product(self.shape) + else: + self.stride = _stride + + # operator == + def __eq__(self, other): + return self.shape == other.shape and self.stride == other.stride + + # operator len(L) (len [rank] like tuples) + def __len__(self): + if is_tuple(self.shape): + return len(self.shape) + else: + return 1 + + # operator () (map coord to idx) + def __call__(self, *args): + """ + Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + OR + Slice the layout and return the sublayout (Coord has an Underscore slice op) + + Follow the same behavior of `Layout::operator(Coord const&)` in cute C++ + """ + if has_none(args): + if len(args) == 1: + return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride)) + else: + return Layout(slice_(args, self.shape), slice_(args, self.stride)) + else: + if len(args) == 1: + return crd2idx(args[0], self.shape, self.stride) + else: + return crd2idx(args, self.shape, self.stride) + + # operator [] (get-i like tuples) + def __getitem__(self, i): + if is_tuple(self.shape): + return Layout(self.shape[i], self.stride[i]) + else: + assert i == 0 + return Layout(self.shape, self.stride) + + # size(layout) Size of the domain + def size(self): + return product(self.shape) + + # cosize(layout) Size of the codomain + def cosize(self): + return self(self.size() - 1) + 1 + + # print and str + def __str__(self): + return f"{self.shape}:{self.stride}" + + # error msgs and representation + def __repr__(self): + return f"Layout({self.shape},{self.stride})" + + +# Make Layout from a list of layouts (each layout it's own mode in the result) +def make_layout(*layouts): + if len(layouts) == 1 and not is_layout(layouts[0]): + layouts = layouts[0] + + shape, stride = zip(*((a.shape,a.stride) for a in layouts)) + return Layout(shape, stride) + + +# Size of the domain +def size(layout): + if is_layout(layout): + return layout.size() + return product(layout) + + +# Size of the codomain +def cosize(layout): + return layout.cosize() + + +# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function +def coalesce(layout, profile=None): + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))), + (layout[i] for i in range(len(profile),len(layout))))) + + result_shape = [1] + result_stride = [0] + for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): + # skip their shape-1s + if shape == 1: + continue + # replace our shape-1 with anything + elif result_shape[-1] == 1: + result_shape[-1] = shape + result_stride[-1] = stride + # merge modes if the shape*stride match + elif result_shape[-1] * result_stride[-1] == stride: + result_shape[-1] = result_shape[-1] * shape + # append a new mode + else: + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them +def filter(layout, profile=None): + if is_tuple(profile): + assert len(layout) >= len(profile) + return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))), + (layout[i] for i in range(len(profile),len(layout))))) + + result_shape = [] + result_stride = [] + for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)): + # skip their shape-1s and stride-0s + if not (shape == 1 or stride == 0): + result_shape.append(shape) + result_stride.append(stride) + + if len(result_shape) == 0: + return Layout(1,0) + else: + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout composition +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def composition(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return composition(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + elif is_tuple(layoutB.shape): + return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB) + + if layoutB.stride == 0: + return Layout(layoutB.shape, 0) + else: + result_shape = [] + result_stride = [] + rest_shape = layoutB.shape + rest_stride = layoutB.stride + flat_A = coalesce(layoutA) + for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]): + assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0 + new_shape = min(max(1, curr_shape // rest_stride), rest_shape) + + if new_shape != 1: + result_shape.append(new_shape) + result_stride.append(rest_stride * curr_stride) + + rest_shape = rest_shape // new_shape + rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride) + + if rest_shape != 1 or len(result_shape) == 0: + result_shape.append(rest_shape) + result_stride.append(rest_stride * flatten(flat_A.stride)[-1]) + + if len(result_shape) == 1: + return Layout(result_shape[0], result_stride[0]) + else: + return Layout(tuple(result_shape), tuple(result_stride)) + + +# Layout complement +def complement(layout, max_idx=1): + if is_int(layout): + return complement(Layout(layout)) + + result_shape = [] + result_stride = [] + current_idx = 1 + + sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape))) + for (stride, shape) in sorted_DS: + if stride == 0 or shape == 1: + continue + + in_bound = current_idx <= shape * stride + # To support symbolic value which can't be evaluated now + assert (type(in_bound) is not bool) or in_bound + + result_shape.append(stride // current_idx) + result_stride.append(current_idx) + current_idx = shape * stride + + result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div + result_stride.append(current_idx) + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout right inverse +def right_inverse(layout): + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + + result_shape = [] + result_stride = [] + current_idx = 1 + + flat_shape = flatten(layout.shape) + flat_stride = flatten(layout.stride) + sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape))) + for (stride,shape,rstride) in sorted_DSA: + if shape == 1: + continue + if current_idx != stride: + break + + result_shape.append(shape) + result_stride.append(rstride) + current_idx = shape * stride + + return coalesce(Layout(tuple(result_shape), tuple(result_stride))) + + +# Layout left inverse +def left_inverse(layout): + if layout is None: + return None + elif is_int(layout): + return Layout(layout) + return right_inverse(make_layout(layout, complement(layout))) + + +# Split a layout by the composition of B and the "rest" +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_divide(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + + return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA)))) + + +# Reproduce a layoutA over a layoutB +# Use tuples-of-layouts to perform this operation by-mode and None as no-op +def logical_product(layoutA, layoutB): + if layoutB is None: + return layoutA + elif is_int(layoutB): + return logical_divide(layoutA, Layout(layoutB)) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA))))) + + return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB)); + + +# Gather the modes from a hierarchical logical_divide or logical_product +def hier_unzip(splitter, layoutA, layoutB): + if layoutB is None: + return make_layout(Layout(1,0), layoutA) + elif is_tuple(layoutB): + assert len(layoutA) >= len(layoutB) + # A layout with shape ((A,a),(B,b),(C,c)) + split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB))) + # Gather to shape ((A,B,C,...),(a,b,c,...,y,z)) + return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))), + make_layout(chain((split[i][1] for i in range( 0,len(layoutB))), + (layoutA[i] for i in range(len(layoutB),len(layoutA)))))) + + # splitter must return a rank-2 layout + return splitter(layoutA, layoutB) + + +# Apply logical divide hierarchically and gather the split modes into two modes +def zipped_divide(layoutA, layoutB): + return hier_unzip(logical_divide, layoutA, layoutB) + + +# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode +def tiled_divide(layoutA, layoutB): + result = zipped_divide(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) + + +# Apply logical product hierarchically and gather the split modes into two modes +def zipped_product(layoutA, layoutB): + return hier_unzip(logical_product, layoutA, layoutB) + + +# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode +def tiled_product(layoutA, layoutB): + result = zipped_product(layoutA, layoutB) + return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))]) + + +def slice_and_offset(crd: tuple, + layout: Layout): + return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)), + crd2idx(crd, layout.shape, layout.stride)) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py new file mode 100644 index 0000000000000000000000000000000000000000..308aee0c3838a82c4de53833fb8a36950b30f62d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Methods for layout swizzling +""" + +from .layout import * + + +def shiftr(a, s): + return a >> s if s > 0 else shiftl(a, -s) + + +def shiftl(a, s): + return a << s if s > 0 else shiftr(a, -s) + + +## A generic Swizzle functor + # 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + # ^--^ Base is the number of least-sig bits to keep constant + # ^-^ ^-^ Bits is the number of bits in the mask + # ^---------^ Shift is the distance to shift the YYY mask + # (pos shifts YYY to the right, neg shifts YYY to the left) + # + # e.g. Given + # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + # the result is + # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + # +class Swizzle: + def __init__(self, bits, base, shift): + assert bits >= 0 + assert base >= 0 + assert abs(shift) >= bits + self.bits = bits + self.base = base + self.shift = shift + bit_msk = (1 << bits) - 1 + self.yyy_msk = bit_msk << (base + max(0,shift)) + self.zzz_msk = bit_msk << (base - min(0,shift)) + + # operator () (transform integer) + def __call__(self, offset): + return offset ^ shiftr(offset & self.yyy_msk, self.shift) + + # Size of the domain + def size(self): + return 1 << (self.bits + self.base + abs(self.shift)) + + # Size of the codomain + def cosize(self): + return self.size() + + # print and str + def __str__(self): + return f"SW_{self.bits}_{self.base}_{self.shift}" + + # error msgs and representation + def __repr__(self): + return f"Swizzle({self.bits},{self.base},{self.shift})" + + +class ComposedLayout(LayoutBase): + def __init__(self, layoutB, offset, layoutA): + self.layoutB = layoutB + self.offset = offset + self.layoutA = layoutA + + # operator == + def __eq__(self, other): + return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA + + # operator len(L) (len [rank] like tuples) + def __len__(self): + return len(self.layoutA) + + # operator () (map coord to idx) + def __call__(self, *args): + return self.layoutB(self.offset + self.layoutA(*args)) + + # operator [] (get-i like tuples) + def __getitem__(self, i): + return ComposedLayout(self.layoutB, self.offset, self.layoutA[i]) + + # size(layout) Size of the domain + def size(self): + return size(self.layoutA) + + # cosize(layout) Size of the codomain + def cosize(self): + return cosize(self.layoutB) + + # print and str + def __str__(self): + return f"{self.layoutB} o {self.offset} o {self.layoutA}" + + # error msgs and representation + def __repr__(self): + return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})" diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py new file mode 100644 index 0000000000000000000000000000000000000000..834f7e5411f5c2a4e218f9ce8a4f0a229d039710 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from abc import ABC + + +class Integer(ABC): + @classmethod + def __subclasshook__(cls, c): + if c in [bool, float]: + return False + + return issubclass(c, int) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..acc0c46e540735443a4943908852010a80d02187 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py @@ -0,0 +1,74 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +import copy +import os +import setuptools +from setuptools import setup +from setuptools.command.build_ext import build_ext + +import setup_pycute +import setup_library + + +# Install cutlass_library package +setup_library.perform_setup() + + +# Install the PyCuTe package +setup_pycute.perform_setup() + + +setup( + name='cutlass_cppgen', + version='4.2.0', + description='CUTLASS Pythonic Interface', + package_dir={'': '.'}, + packages=[ + 'cutlass_cppgen', + 'cutlass_cppgen.emit', + 'cutlass_cppgen.op', + 'cutlass_cppgen.utils', + 'cutlass_cppgen.backend', + 'cutlass_cppgen.backend.utils' + ], + setup_requires=['pybind11'], + install_requires=[ + 'bfloat16', + 'cuda-python>=11.8.0', + 'pybind11', + 'scikit-build', + 'treelib', + 'pydot' + ] +) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py new file mode 100644 index 0000000000000000000000000000000000000000..c56d6b5556fea2d5e56209b13f5b95e487ca22fb --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from setuptools import setup + + +def perform_setup(): + setup( + name='cutlass_library', + version='4.2.1', + description='CUTLASS library generation scripts', + packages=['cutlass_library'] + ) + + +if __name__ == '__main__': + perform_setup() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py new file mode 100644 index 0000000000000000000000000000000000000000..0bad050fcade8b26d33043abbb0f8226be7d816c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py @@ -0,0 +1,46 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from setuptools import setup + + +def perform_setup(): + setup( + name='pycute', + version='4.2.1', + description='Python implementation of CuTe', + packages=['pycute'], + ) + + +if __name__ == '__main__': + perform_setup() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py new file mode 100644 index 0000000000000000000000000000000000000000..852c0277ebae2fce7e0b083ce2f497a2c828256f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py @@ -0,0 +1,661 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utilities for defining Conv2D problem sizes for testing. + +This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h +""" + +from cutlass_library import ConvMode + +import cutlass_cppgen +from cutlass_cppgen.shape import Conv2DProblemSize + + +class TestbedConv2dProblemSizes: + def __init__(self, minimum_channel_size: int): + conv2d_default_sizes = self.initialize_conv2d_default_sizes(minimum_channel_size) + conv2d_rigorous_sizes = self.initialize_conv2d_rigorous_sizes(minimum_channel_size) + conv2d_resnet50_sizes = self.initialize_conv2d_resnet50_sizes(1) + conv2d_resnet50_sizes_perf = self.initialize_conv2d_resnet50_sizes(34) + grouped_sizes = self.initialize_conv2d_grouped_sizes() + + # Filter all problems + self.all = [] + for size_list in [conv2d_default_sizes, conv2d_rigorous_sizes, conv2d_resnet50_sizes, conv2d_resnet50_sizes_perf, grouped_sizes]: + for size in size_list: + if (size.C // size.groups) % minimum_channel_size == 0: + self.all.append(size) + + + def initialize_conv2d_default_sizes(self, minimum_channel_size): + # Small input size x stride (1,1) + # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + + conv2d_default_sizes = [] + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 1, 1, minimum_channel_size, + 8, 1, 1, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 1, 8, minimum_channel_size, + 8, 1, 3, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 7, 8, minimum_channel_size, + 8, 3, 3, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 7, 9, minimum_channel_size, + 8, 4, 4, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 2, 7, 9, minimum_channel_size, + 8, 5, 5, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 7, 9, minimum_channel_size, + 8, 6, 5, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 7, 9, minimum_channel_size, + 8, 6, 6, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 7, 9, minimum_channel_size, + 8, 7, 7, minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + ############################################## + # Small input size x stride (2,2) + # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + ############################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 11, 7, minimum_channel_size, + 8, 1, 1, minimum_channel_size, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 11, 7, minimum_channel_size, + 8, 3, 3, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 13, 11, minimum_channel_size, + 8, 1, 1, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 17, 19, minimum_channel_size, + 16, 2, 2, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 23, 5, minimum_channel_size, + 16, 3, 3, minimum_channel_size, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 13, 17, 8, + 24, 3, 3, 8, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 23, 21, 8, + 24, 3, 3, 8, + 1, 1, + 3, 3, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 20, 24, 8, + 40, 3, 3, 8, + 3, 3, + 3, 3, + 1, 1, + )) + + ########################################## + # Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 15, 19, 160, + 224, 1, 1, 160, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 19, 37, 160, + 224, 3, 3, 160, + 1, 1, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 16, 16, 160, + 224, 2, 3, 160, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 23, 21, 128, + 224, 3, 3, 128, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 29, 37, 160, + 224, 5, 5, 160, + 2, 2, + 1, 1, + 1, 1, + )) + + ########################################## + # C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 15, 19, 32 + minimum_channel_size, + 96, 3, 3, 32 + minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 16, 24, 64 + minimum_channel_size, + 96, 3, 3, 64 + minimum_channel_size, + 1, 1, + 1, 1, + 1, 1, + )) + + ########################################## + # Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 13, 16, 288, + 160, 5, 5, 288, + 2, 2, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 55, 51, 256, + 512, 1, 1, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 71, 80, 32, + 64, 5, 5, 32, + 2, 2, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 224, 224, 8, + 64, 7, 7, 8, + 3, 3, + 2, 2, + 1, 1, + )) + + ########################################## + # Medium input size stride (3, 3), filter (3, 3), non-default padding + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 23, 256, + 512, 3, 3, 256, + 0, 0, + 3, 3, + 1, 1, + )) + + ########################################## + # Medium input size padding > stride, asymmetric filter, padding and striding + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 31, 256, + 512, 3, 3, 256, + 5, 7, + 3, 4, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 35, 256, + 512, 7, 5, 256, + 11, 7, + 3, 5, + 1, 1, + )) + + ########################################## + # Medium input size *mixed* stride (1, 2) and (2, 1), + # filter (3, 3), default padding + ########################################## + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 27, 256, + 512, 3, 3, 256, + 1, 1, + 1, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 27, 27, 256, + 512, 3, 3, 256, + 1, 1, + 2, 1, + 1, 1, + )) + + ######################################/ + # Additional input size + ######################################/ + conv2d_default_sizes.append(Conv2DProblemSize( + 3, 28, 28, 256, + 256, 2, 2, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 1, 32, 32, 16, + 32, 3, 3, 16, + 1, 1, + 6, 2, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 32, 24, 32, 32, + 32, 1, 2, 32, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_default_sizes.append(Conv2DProblemSize( + 4, 2, 3, 256, + 328, 3, 5, 256, + 1, 1, + 1, 1, + 1, 1, + )) + return conv2d_default_sizes + + # Add a few large and rigorous convolution problem sizes + def initialize_conv2d_rigorous_sizes(self, minimum_channel_size): + sizes = [] + if False: + sizes.append(Conv2DProblemSize.from_sizes( + (1, 124, 224, 2 * minimum_channel_size), + (24, 7, 7, 2 * minimum_channel_size), + )) + + sizes.append(Conv2DProblemSize.from_sizes( + (1, 233, 35, minimum_channel_size), + (24, 7, 5, minimum_channel_size), + )) + return sizes + + # Add resent50 layers to unit testing sizes + def initialize_conv2d_resnet50_sizes(self, batch_size): + conv2d_problem_vector = [] + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 64, + 256, 1, 1, 64, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 64, + 64, 1, 1, 64, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 64, + 64, 3, 3, 64, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 256, + 64, 1, 1, 256, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 256, + 512, 1, 1, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 56, 56, 256, + 128, 1, 1, 256, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 128, + 128, 3, 3, 128, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 128, + 512, 1, 1, 128, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 512, + 128, 1, 1, 512, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 512, + 1024, 1, 1, 512, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 28, 28, 512, + 256, 1, 1, 512, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 256, + 256, 3, 3, 256, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 256, + 1024, 1, 1, 256, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 1024, + 256, 1, 1, 1024, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 1024, + 2048, 1, 1, 1024, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 14, 14, 1024, + 512, 1, 1, 1024, + 0, 0, + 2, 2, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 7, 7, 512, + 512, 3, 3, 512, + 1, 1, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 7, 7, 512, + 2048, 1, 1, 512, + 0, 0, + 1, 1, + 1, 1, + )) + + conv2d_problem_vector.append(Conv2DProblemSize( + batch_size, 7, 7, 2048, + 512, 1, 1, 2048, + 0, 0, + 1, 1, + 1, 1, + )) + + return conv2d_problem_vector + + def initialize_conv2d_grouped_sizes(self): + threadblock_n = 128 + threadblock_k = 32 + + sizes = [] + ########################################## + # One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + # One CTA calculates a single group + ########################################## + for cta_per_group_k in range(1, 4): + for groups in range(2, 5): + conv_k = cta_per_group_k * threadblock_n * groups + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 2 * groups, + conv_k, 3, 3, threadblock_k * 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + groups + )) + + # Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k, + threadblock_n * 2, 3, 3, threadblock_k // 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 2 + )) + + sizes.append(Conv2DProblemSize( + 1, 56, 56, 696, + 768, 3, 3, 232, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 3 + )) + sizes.append(Conv2DProblemSize( + 1, 14, 14, 1392, + 1536, 3, 3, 232, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 3 + )) + + ########################################## + # One CTA calculate multiple groups: CTA::N % k_per_group = 0 + ########################################## + + # 2 groups per CTA + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 4, + threadblock_n, 3, 3, threadblock_k * 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 2 + )) + + # 2 groups per CTA and partial gemm_k + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k, + threadblock_n, 3, 3, threadblock_k // 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 2 + )) + + # 4 groups per CTA + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 8, + threadblock_n // 2, 3, 3, threadblock_k * 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 4 + )) + + # 4 groups per CTA and partial gemm_k + sizes.append(Conv2DProblemSize( + 1, 8, 8, threadblock_k * 2, + threadblock_n // 2, 3, 3, threadblock_k // 2, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, + 4 + )) + + return sizes diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..f77a0ec831be087bd3badc929eee955f0b37c489 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py @@ -0,0 +1,146 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for Conv2d opreations on SM80 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from conv2d_test_utils import * + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is invalid for SM80 tests.') +class Conv2dSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +conv_problems = get_conv_problems() + + +# Tests for optimized & analytic +for conv_kind in ["fprop", "wgrad", "dgrad"]: + # F16, simt + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="simt", threadblock_shape=[128, 128, 8], + warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1]) + # F16, tensor op + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) + # F16, tensor op, analytic iterator + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic") + # F16, tensor op, f32 output + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) + # F16, tensor op, different tile description + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 64, 32], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]) + # F32, simt + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, + opclass="simt", threadblock_shape=[128, 128, 8], + warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1]) + # Tf32, tensorop + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, + opclass="tensor_op", threadblock_shape=[128, 128, 16], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8] + ) + # Split-K + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial", + split_k_slices=2) + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel", + split_k_slices=5) + # Swizzling functor + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 64, 32], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4) + +# Tests for few channels and fixed channels +# F16, tensor op, few channels +for c, tb, stage, inst in zip([2, 1], + [[128, 128, 64], [128, 128, 32]], + [3, 2], + [[16, 8, 16], [16, 8, 8]]): + add_test( + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=tb, + warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels" + ) +# F16, tensor op, fixed channels +for c in [8, 4, 2]: + add_test( + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels" + ) + +# Test activations +for activation in ["relu", "leaky_relu"]: + for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]): + add_test( + Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode, + split_k_slices=split_k_slices, activation=activation) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc4542cd5ccf72341f7db3c7947d481b032926d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py @@ -0,0 +1,428 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility functions for Conv2d tests. +""" + +from cutlass_library import SubstituteTemplate +import torch + +import cutlass_cppgen +from cutlass_library import ( + ConvKind, + ConvMode, + DataType, + DataTypeNames, + EpilogueScheduleSuffixes, + KernelScheduleSuffixes, + LayoutType, + OpcodeClassNames, + ShortDataTypeNames, + ShortLayoutTypeNames, + SplitKMode, +) +from cutlass_cppgen.shape import Conv2DProblemSize +from cutlass_cppgen.utils.datatypes import numpy_type, torch_type + +from conv2d_problem_sizes import TestbedConv2dProblemSizes + + +def get_name_conv2d( + arch, + conv_kind, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm, + swizzle, + split_k_mode, + split_k_slices, + activation +): + """ + Generates a procedural name for a test case for conv2d + + :param arch: compute capability of kernel being generated + :type arch: int + :param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad) + :type conv_kind: str + :param iterator_algorithm: the iterator algorithm applied + :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param element_accumulator: data type used in accumulation + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param stride_support: stride support of dgrad + :param alignment: int + :type alignment: int + + :return: str + """ + if iterator_algorithm is None: + iterator_algorithm = "AUTO" + if swizzle is None: + swizzle = 1 + name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}" + + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "conv_kind": conv_kind, + "iter_alg": iterator_algorithm, + "eA": DataTypeNames[element], + "eB": DataTypeNames[element], + "eC": DataTypeNames[element_output], + "opclass": opclass, + "acc": DataTypeNames[element_accumulator], + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "wM": str(threadblock_shape[0] // warp_count[0]), + "wN": str(threadblock_shape[1] // warp_count[1]), + "wK": str(threadblock_shape[2] // warp_count[2]), + "IM": str(instruction_shape[0]), + "IN": str(instruction_shape[1]), + "IK": str(instruction_shape[2]), + "stages": str(stages), + "swizzle": str(swizzle), + "split_k_mode": split_k_mode, + "split_k_slices": str(split_k_slices), + "activation": activation + } + ) + + +def conv2d_few_channel_problemsizes(channels): + problem_sizes = [ + Conv2DProblemSize( + 1, 8, 8, channels, + 16, 3, 3, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 16, 16, channels, + 16, 3, 3, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 16, 16, channels, + 16, 7, 7, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 32, 7, 7, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 7, 7, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 5, 5, channels, + 1, 1, + 1, 1, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 224, 224, channels, + 64, 5, 5, channels, + 1, 1, + 2, 2, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + ] + + return problem_sizes + + +def validate_problem_size(ps, conv_kind, split_k_slices): + P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1 + Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1 + if P != ps.P or Q != ps.Q: + return False + + # Split-K (serial or parallel) is not supported for strided dgrad + if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1): + return False + return True + + +class Conv2dLauncherFrontend: + def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"): + self.operation = plan + self.conv_kind = plan.conv_kind + self.seed = seed + self.backend = backend + + self.dtype_A = plan._element_a + self.dtype_B = plan._element_b + self.dtype_C = plan._element_c + self.dtype_acc = plan._element_accumulator + self.layout_A = LayoutType.TensorNHWC + self.layout_B = LayoutType.TensorNHWC + self.layout_C = LayoutType.TensorNHWC + self.layout_D = LayoutType.TensorNHWC + + self.element_compute = DataType.f32 + + if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]: + self.rand_max = 1 + else: + self.rand_max = 4 + self.activation = plan.activation + + def uniform_init(self, size, dtype): + tensor = torch.ceil( + torch.empty(size=size, dtype=torch_type(dtype), device="cuda").uniform_(-self.rand_max - 0.5, self.rand_max - 0.5) + ).to(memory_format=torch.channels_last) + return tensor + + def reference(self, ps, A, B, C, alpha, beta, activation): + if self.conv_kind == ConvKind.Fprop: + torch_result = alpha * torch.ops.aten.conv2d( + A, + B, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w) + ) + beta * C + elif self.conv_kind == ConvKind.Dgrad: + torch_result = alpha * torch.nn.grad.conv2d_input( + (ps.N, ps.C, ps.H, ps.W), + B, + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + elif self.conv_kind == ConvKind.Wgrad: + torch_result = alpha * torch.nn.grad.conv2d_weight( + B, + (ps.K, ps.C, ps.R, ps.S), + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + else: + raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.") + + if activation == cutlass_cppgen.backend.epilogue.relu: + torch_result = torch.nn.functional.relu(torch_result) + elif activation == cutlass_cppgen.backend.epilogue.leaky_relu: + torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5) + return torch_result + + def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0): + if self.conv_kind == ConvKind.Fprop: + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + elif self.conv_kind == ConvKind.Dgrad: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + elif self.conv_kind == ConvKind.Wgrad: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.N, ps.C, ps.H, ps.W) + tensor_C_size = (ps.K, ps.C, ps.R, ps.S) + else: + raise Exception(f"Conv kind {self.conv_kind} is not supported") + + torch.manual_seed(self.seed) + + tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A) + tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B) + tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C) + tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last) + args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w), + alpha=alpha, beta=beta, + split_k=(split_k_mode, split_k_slices)) + + args.sync() + + tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation) + + torch.cuda.synchronize() + passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06) + + return passed + + +def add_test( + cls, + cc, + conv_kind, + problem_sizes, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm=None, + swizzle=None, + split_k_mode="serial", + split_k_slices=1, + activation = "identity" +): + """Create a test-running function with the given specification""" + test_name = get_name_conv2d( + cc, conv_kind, element, element_accumulator, + element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages, + iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation) + + def run(self): + # Create the plan + plan = cutlass_cppgen.Conv2d( + kind=conv_kind, + element=element, + element_accumulator=element_accumulator, + element_C=element_output, + element_D=element_output + ) + + # Set the opclass + plan.opclass = opclass + # Set the tile description + td = { + "threadblock_shape": threadblock_shape, + "warp_count": warp_count, + "stages": stages, + "instruction_shape": instruction_shape, + } + + plan.tile_description = td + # Set iterator algorithm + if iterator_algorithm is not None: + plan.iterator_algorithm = iterator_algorithm + # Set swizzling functor + if swizzle is not None: + plan.swizzling_stride = swizzle + + if activation != "identity": + if activation == "leaky_relu": + plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5) + else: + plan.activation = getattr(cutlass_cppgen.epilogue, activation) + + conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch") + + for ps in problem_sizes: + if not validate_problem_size(ps, conv_kind, split_k_slices): + continue + + self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0)) + + setattr(cls, test_name, run) + + return run + + +def get_conv_problems(): + # 64: minimum channel size + conv_problems = TestbedConv2dProblemSizes(64).all + + # Insert alignment 4 & 2 tests + conv_problems += [ + Conv2DProblemSize( + 1, 4, 4, 12, + 8, 3, 3, 12, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 4, 4, 14, + 8, 3, 3, 14, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + Conv2DProblemSize( + 1, 23, 56, 98, + 128, 3, 3, 98, + 4, 5, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ), + ] + + return conv_problems diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..d892b5df047d5121345d902a77aadf2256b4c3b5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'conv2d_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d4c52a9f75fb4c3bc809947bf48ba85356ec70 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py @@ -0,0 +1,309 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests emitting a CUTLASS kernel to a PyTorch CUDA extension +""" + +import random +import tempfile +import unittest + +from cutlass_library import ConvMode + +import cutlass_cppgen + +if cutlass_cppgen.utils.datatypes.is_torch_available(): + import torch + + +def _initialize(dtype, M: int, N: int, K: int): + """ + Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K + + :param dtype: data type of tensors + :param M: M dimension of GEMM problem + :type M: int + :param N: N dimension of GEMM problem + :type N: int + :param K: N dimension of GEMM problem + :type K: int + + :return: initialized tensors A, B, C, and D + :rtype: list + """ + sizes = [(M, K), (K, N), (M, N), (M, N)] + return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes] + + +def _generate_problems(dtype, num): + """ + Utility function to generate `num` GEMMs of random sizes + + :param dtype: data type of tensors + :param num: number of GEMMs to generate + :type num: int + + :return: lists of A, B, C, and D tensors + :rtype: list + """ + valid_sizes = [128, 256, 512, 1024] + As, Bs, Cs, Ds = [], [], [], [] + for _ in range(num): + M, N, K = [random.choice(valid_sizes) for _ in range(3)] + A, B, C, D = _initialize(dtype, M, N, K) + As.append(A) + Bs.append(B) + Cs.append(C) + Ds.append(D) + return As, Bs, Cs, Ds + +def _generate_conv2d_problem(conv_kind, dtype, ps): + """ + Utility function to generate conv2d inputs + + :param conv_kind: kind of convolution + :type conv_kind: str + :param dtype: data type of tensors + :param problem_size: the conv2d problem size + :type problem_size: cutlass_cppgen.shape.Conv2DProblemSize + + :return: initialized tensors A, B, C, and D + :rtype: list + """ + if conv_kind == "fprop": + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + elif conv_kind == "dgrad": + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + else: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.N, ps.C, ps.H, ps.W) + tensor_C_size = (ps.K, ps.C, ps.R, ps.S) + sizes = [tensor_A_size, tensor_B_size, tensor_C_size] + return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes] + + +@unittest.skipIf(not cutlass_cppgen.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests') +class PyTorchExtensionTest(unittest.TestCase): + + def test_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.Gemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + A, B, C, _ = _initialize(dtype, 1024, 256, 512) + + D_ref = A @ B + D = mod.run(A, B) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0) + assert torch.allclose(D, D_ref) + + D = mod.run(A, B, C, 1.0, 0.0) + assert torch.allclose(D, D_ref) + + alpha = 2.0 + beta = -1.0 + D_ref = (A @ B) * alpha + (beta * C) + D = mod.run(A, B, C, alpha, beta) + assert torch.allclose(D, D_ref) + + def test_grouped_gemm(self): + random.seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor) + op = plan.construct() + + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True) + + As, Bs, Cs, _ = _generate_problems(dtype, 50) + + def check_all(X, Y): + for x, y in zip(X, Y): + assert torch.allclose(x, y) + + Ds_ref = [a @ b for a, b in zip(As, Bs)] + Ds = mod.run(As, Bs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0) + check_all(Ds, Ds_ref) + + Ds = mod.run(As, Bs, Cs, 1.0, 0.0) + check_all(Ds, Ds_ref) + + alpha = 2.0 + beta = -1.0 + Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] + Ds = mod.run(As, Bs, Cs, alpha, beta) + check_all(Ds, Ds_ref) + + def test_conv2d_fprop(self): + torch.manual_seed(2023) + + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) + plan.activation = "relu" + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + + D_ref = alpha * torch.ops.aten.conv2d( + A, B, stride=stride, padding=padding + ) + beta * C + D_ref = torch.nn.functional.relu(D_ref) + D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + + def test_conv2d_dgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W) + D_ref = alpha * torch.nn.grad.conv2d_input( + input_size, B, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, ) + + assert torch.allclose(D, D_ref) + + def test_conv2d_wgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_cppgen.shape.Conv2DProblemSize( + 1, 4, 4, 16, + 8, 3, 3, 16, + 0, 0, + 3, 3, + 1, 1, + ConvMode.CrossCorrelation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S) + D_ref = alpha * torch.nn.grad.conv2d_weight( + B, weight_size, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..5467469e74e05573fb297b009914e0980e5ab222 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -0,0 +1,198 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Unit test for compute node in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen import swizzle + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTCompute(EVTTestCaseBase): + + def test_arith(self): + """ + Test Arithmatic op + """ + def evt_arith_compute(accum, C, alpha, beta, gamma): + D = ((accum + C) * alpha - gamma) / beta + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "gamma": 2.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_arith_compute, example_inputs) + input_keys = ["C", "alpha", "beta", "gamma"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_func_call(self): + """ + Test Function call + """ + def evt_func_call(accum, C, alpha, beta, gamma): + D = multiply_add(relu(accum + alpha) + C, beta, gamma) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "gamma": 2.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_func_call, example_inputs) + input_keys = ["C", "alpha", "beta", "gamma"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_func_call2(self): + """ + Test Function call + """ + + def evt_func_call2(accum, C, alpha, beta): + D = maximum(alpha * accum + beta * C, 0.0) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_func_call2, example_inputs) + input_keys = ["C", "alpha", "beta"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_tanh(self): + """ + Test Tanh op + """ + def evt_tanh(accum): + D = tanh(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_tanh, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_sigmoid(self): + """ + Test Sigmoid op + """ + def evt_sigmoid(accum): + D = sigmoid(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_gelu(self): + """ + Test GELU op + """ + def evt_gelu(accum): + D = gelu(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_gelu, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_exp(self): + """ + Test Exp op + """ + def evt_exp(accum): + D = exp(accum) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_exp, example_inputs) + input_keys = [] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a7b7f7a336dce0651f299d26b17df04952be99 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py @@ -0,0 +1,173 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLayout(EVTTestCaseBase): + + def test_permute_1(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(0, 2, 1)) + D_permute = F_permute + permute(C, indices=(0, 2, 1)) + D = permute(D_permute, indices=(0, 2, 1)) + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") + def test_permute_2(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(0, 2, 1)) + D = F_permute + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, n, m)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, n, m)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") + def test_permute_3(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(1, 0, 2)) + D = F_permute + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (m, l, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (m, l, n)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_reshape(self): + """ + Test reshape + """ + def evt_reshape(accum, alpha, TensorE): + F = alpha * accum + E_reshape = reshape(TensorE, new_shape=(512, 1)) + D = F + E_reshape + return D + + example_inputs = { + "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), + "alpha": 0.5, + "TensorE": self.fake_tensor(self.element, (16, 32)), + "D": self.fake_tensor(self.element, (self.l, self.m, self.n)), + } + + launcher = EVTTestBed(self.element, evt_reshape, example_inputs) + input_keys = ["alpha", "TensorE"] + result_keys = ["D"] + launcher.verify(self.problem_size, input_keys, result_keys, self.l) + + def test_reshape2(self): + """ + Test reshape + """ + def evt_reshape(accum, alpha, TensorE): + F = alpha * accum + F_reshape = reshape(F, new_shape=(2, 3, 512, 256)) + D = F_reshape + TensorE + return D + + example_inputs = { + "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), + "alpha": 0.5, + "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)), + "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)), + } + + launcher = EVTTestBed(self.element, evt_reshape, example_inputs) + input_keys = ["alpha", "TensorE"] + result_keys = ["D"] + launcher.verify(self.problem_size, input_keys, result_keys, self.l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..57a5c6bb17bb44bf294cc7a6a749c706601034f6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py @@ -0,0 +1,142 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for load nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLoad(EVTTestCaseBase): + + def test_tensor_load(self): + """ + Load extra tensor with shape [m, n] + """ + def evt_tensor_load(accum, C, aux, aux_batch): + D = accum + C + aux + aux_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "aux": self.fake_tensor(self.element, (m, n)), + "aux_batch": self.fake_tensor(np.float32, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs) + input_keys = ["C", "aux", "aux_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_broadcast(self): + """ + Load extra tensor with shape [1, n] + """ + def evt_row_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (n,)), + "bias_batch": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_column_broadcast(self): + """ + Load extra tensor with shape [m, 1] + """ + def evt_column_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m, 1)), + "bias_batch": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_broadcast(self): + """ + Load extra tensor with shape [1, 1] + """ + def evt_scalar_broadcast(accum, C, alpha, alpha_batch): + D = accum + C + alpha + alpha_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs) + input_keys = ["C", "alpha", "alpha_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..30dc8fe0d5ec413f1da57a8fa0875ed5e7baa887 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -0,0 +1,319 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unittest for mixed types of nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTMixed(EVTTestCaseBase): + + def test_same_variable_used_multiple_times(self): + """ + The same variable z0 is used multiple times + """ + def evt_aux_store(accum): + z0 = relu(accum) + D = z0 + z0 + return z0, D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "z0": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["accum"] + result_keys = ["z0", "D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_no_lca(self): + """ + The same variable z0 is used multiple times + """ + def evt_no_lca(accum, bias): + E = relu(accum) + F = E + bias + tmp_2 = E + 2 + D = tmp_2 + E + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)), + } + + launcher = EVTTestBed(self.element, evt_no_lca, example_inputs) + input_keys = ["accum", "bias"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + if device_cc() == 80: + alignments = [2, 4, 8] + else: + # Sm90 EVT currently only supports 128-bit alignment + alignments = [8,] + for align in alignments: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_float(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for align in [3, 2, 4]: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(np.float32, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(np.float32, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(np.float32, (l, m, n)), + "cbias": self.fake_tensor(np.float32, (m, 1)), + "rbias": self.fake_tensor(np.float32, (n,)), + "D": self.fake_tensor(np.float32, (l, m, n)), + "F": self.fake_tensor(np.float32, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (n,)), + "E_col_max": self.fake_tensor(np.float32, (m, 1)) + } + launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stage2(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_partition_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + tile_description = { + "threadblock_shape": [128, 128, 64], + "warp_count": [2, 2, 2] + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stream_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + # High per-sm occupancy tile_description + tile_description = { + "threadblock_shape": [128, 128, 32], + "warp_count": [2, 2, 1], + "stages": 3 + } + tds = [None, tile_description] + for td in tds: + for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]): + if l == 1: + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + else: + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + if td is not None: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + tile_description=td, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + else: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag_no_batch(self): + def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, _ in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, 1) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f11e4f3bde3499948ae68b1b5bb79347f0fd1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py @@ -0,0 +1,180 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTStore(EVTTestCaseBase): + + @unittest.skipIf(device_cc() != 90, "This test is only for CC 90") + def test_invalid_store(self): + """ + Test invalid store + """ + def evt_invalid_store(accum): + D = accum + F = D + 1 # D has users, which is not allowed on SM90 or higher + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)) + } + with self.assertRaisesRegex( + RuntimeError, + r"On SM90 or higher, D is expected to be a output node with 0 users " + r"to enable smem reuse between C and D, but got 1" + ): + launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs) + + break # Only need to test once + + def test_aux_store(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_aux_store(accum, alpha, C): + F = alpha * accum + D = F + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_col_reduce(self): + """ + Reduction [m, n] -> [m, 1] + """ + def evt_row_reduce(accum, alpha, C): + acc_row_max = max(accum, dim=[2,]) + F = alpha * accum + F_row_max = max(F, dim=[0, 2]) + D = F + C + return D, F_row_max, acc_row_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (m, 1)), + "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_row_max", "acc_row_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_reduce(self): + """ + Reduction [m, n] -> [n] + """ + def evt_col_reduce(accum, alpha, C): + acc_col_max = max(accum, dim=[1,]) + F = alpha * accum + F_col_max = max(F, dim=[0, 1]) + D = F + C + return D, F_col_max, acc_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_col_max": self.fake_tensor(np.float32, (n,)), + "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_col_max", "acc_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_reduce(self): + """ + Reduction [m, n] -> [1,] + """ + def evt_scalar_reduce(accum, alpha, C): + acc_max = max(accum, dim=[1, 2]) + F = alpha * accum + F_max = max(F, dim=[0, 1, 2]) + D = F + C + return D, F_max, acc_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "acc_max": self.fake_tensor(np.float32, (l, 1, 1)), + "F_max": self.fake_tensor(np.float32, (1,)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_max", "acc_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb84e2e8c85e602b45b9ee18ce324accd3a32cd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'evt_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..62d375d856ffaef6be50b39b76121e0eb78a7465 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py @@ -0,0 +1,235 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Testbed classes of EVT +""" + +import torch +import unittest + +import cutlass_cppgen +from cutlass_cppgen import Tensor +import cutlass_cppgen.backend.evt +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import torch_type +from cutlass_cppgen.utils.profiler import CUDAEventProfiler + + +class EVTReferenceModule: + def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor): + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.epilogue_visitor = epilogue_visitor + + def run(self, A, B, C, problem_size, alpha, beta, batch=1): + if self.layout_A == cutlass_cppgen.LayoutType.RowMajor: + A_row = A.view((batch, problem_size.m, problem_size.k)) + else: + A_col = A.view((batch, problem_size.k, problem_size.m)) + A_row = torch.permute(A_col, (0, 2, 1)) + + if self.layout_B == cutlass_cppgen.LayoutType.RowMajor: + B_row = B.view((batch, problem_size.k, problem_size.n)) + else: + B_col = B.view((batch, problem_size.n, problem_size.k)) + B_row = torch.permute(B_col, (0, 2, 1)) + + if self.layout_C == cutlass_cppgen.LayoutType.RowMajor: + C_row = C.view((batch, problem_size.m, problem_size.n)) + else: + C_col = C.view((batch, problem_size.n, problem_size.m)) + C_row = torch.permute(C_col, (0, 2, 1)) + + out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta + + if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor: + out = torch.permute(out_row, (0, 2, 1)) + else: + out = out_row + + return torch.flatten(out) + + def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None): + # Running the mainloop + accum = self.run( + A, B, C, problem_size, 1.0, 0.0, batch=batch + ).reshape(batch, problem_size.m, problem_size.n) + + # Running the epilogue + epilogue_args["accum"] = accum + references = self.epilogue_visitor(**epilogue_args) + + # Return the results + if not isinstance(references, tuple): + references = (references,) + return references + + +class EVTTestBed: + """ + Epilogue Visitor Testbed + """ + def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None: + self.element = element + layout = cutlass_cppgen.LayoutType.RowMajor + self.example_inputs = example_inputs + + # Create the Gemm plan + self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32) + + if "tile_description" in kwargs: + self.plan.tile_description = kwargs["tile_description"] + + if "swizzling_functor" in kwargs: + self.plan.swizzling_functor = kwargs["swizzling_functor"] + + # Compile the epilogue visitor + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs) + if "epilogue_stages" in kwargs: + epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"] + self.plan.epilogue_visitor = epilogue_visitor + + # Reference model + self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor) + + self.profile = profile + + def get_torch_tensor(self, shape, dtype=None, fill=None): + if dtype is None: + dtype = self.element + + dtype = torch_type(dtype) + if fill is None: + return torch.ceil( + torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5) + ) + else: + return torch.full(shape, fill, dtype=dtype, device="cuda") + + def verify(self, problem_size, input_keys, result_keys, batch_count=1): + """ + Verify the results + """ + problem_size = GemmCoord(*problem_size) + + # Initiate the GEMM arguments + tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k)) + tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n)) + + # Initialize the epilogue args + epilogue_args = {} + for key in self.example_inputs.keys(): + if key in input_keys: + tensor = self.example_inputs[key] + if isinstance(tensor, Tensor): + epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element) + else: + epilogue_args[key] = tensor + elif key in result_keys: + tensor = self.example_inputs[key] + if isinstance(tensor, Tensor): + if "max" in key: + fill = -1000 + else: + fill = 0 + epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element, fill=fill) + else: + epilogue_args[key] = tensor + + tensor_D = epilogue_args["D"] + if "C" in epilogue_args: + tensor_C = epilogue_args["C"] + else: + tensor_C = tensor_D + # Run the device kernel + self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args) + + # Run the host reference + evt_args_inputs = {} + for key in input_keys: + evt_args_inputs[key] = epilogue_args[key] + + reference_results = self.reference_fn( + tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs) + + # Compare the results + for result, ref in zip(result_keys, reference_results): + assert torch.equal( + epilogue_args[result].flatten(), + ref.masked_fill(torch.isnan(ref), float('inf')).flatten()) + + # Run profile + if self.profile: + profiler = CUDAEventProfiler( + self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D, + visitor_args = epilogue_args + ) + print(f"Cutlass Python Duration: {profiler()}") + + +class EVTTestCaseBase(unittest.TestCase): + """ + Base class for EVT Unittest + """ + def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None: + super().__init__(methodName) + + self.element = cutlass_cppgen.DataType.f16 + self.l, self.m, self.n, self.k = lmnk + + self.problem_size = (self.m, self.n, self.k) + + torch.random.manual_seed(42) + + def fake_tensor(self, element, shape, stride=None): + if stride is None: + return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor) + else: + return Tensor(element=element, shape=shape, stride=stride) + + def get_problem_sizes(self, alignment, k=None, batch_count=[3,]): + k = k if k else self.k + problem_size_m = [alignment, 512 - 3 * alignment] + problem_size_n = [alignment, 512 - alignment] + if alignment % 8 == 0: + problem_size_m.append(768) + problem_size_n.append(768) + problem_size_l = batch_count + problem_sizes = [] + for m in problem_size_m: + for n in problem_size_n: + for l in problem_size_l: + problem_sizes.append((m, n, k, l)) + + return problem_sizes diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py new file mode 100644 index 0000000000000000000000000000000000000000..155426ab902d1f99eafc7b03c388fc79b4520317 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py @@ -0,0 +1,134 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +High-level tests for running batched GEMMs +""" + +from functools import partial +import logging +from math import prod +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc +import torch + +from utils import LayoutCombination + +cutlass_cppgen.set_log_level(logging.WARNING) + +torch.manual_seed(2023) + + +def pytorch_reference(A, B, C, alpha, beta): + # Get the batch count. Assume that any of A, B, and C + # with a batch dimension ahve matching batch count. Thus, + # we break out of the loop once we have found the first + # tensor containing a batch dimension. + batch_count = (1,) + for tensor in [A, B, C]: + if len(tensor.shape) > 2: + batch_count = tensor.shape[:-2] + break + + int_batch_count = prod(batch_count) + + def add_batch(tensor): + if len(tensor.shape) == 2: + return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1) + else: + return tensor.reshape(-1, tensor.size(-2), tensor.size(-1)) + + # Reshape tensors to have batch dimension + A = add_batch(A) + B = add_batch(B) + C = add_batch(C) + + ret = (torch.bmm(A, B) * alpha) + (C * beta) + reshape_vals = batch_count + C.shape[-2:] + return ret.reshape(*reshape_vals) + + +def initialize(rows, cols, batch): + tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half() + if len(batch) > 0 and prod(batch) > 1: + reshape_vals = batch + (rows, cols) + return tensor.reshape(*reshape_vals) + else: + return tensor.reshape(rows, cols) + + +class GemmF16Batched(unittest.TestCase): + def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool): + M = 512 + N = 256 + K = 128 + alpha = 1. + beta = 2. + + A = initialize(M, K, batch_count if batch_A else (1,)) + B = initialize(K, N, batch_count if batch_B else (1,)) + C = initialize(M, N, batch_count if batch_C else (1,)) + D = initialize(M, N, batch_count) + + plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32) + plan.run(A, B, C, D, alpha, beta) + reference = pytorch_reference(A, B, C, alpha, beta) + assert reference.equal(D) + + def test_batched_ABC(self): + self.run_batched((3,), True, True, True) + self.run_batched((2, 3), True, True, True) + + def test_batched_AB(self): + self.run_batched((3,), True, True, False) + self.run_batched((2, 3), True, True, False) + + def test_batched_AC(self): + self.run_batched((3,), True, False, True) + self.run_batched((2, 3), True, False, True) + + def test_batched_BC(self): + self.run_batched((3,), False, True, True) + self.run_batched((2, 3), False, True, True) + + def test_batched_A(self): + self.run_batched((3,), True, False, False) + self.run_batched((2, 3), True, False, False) + + def test_batched_B(self): + self.run_batched((3,), False, True, False) + self.run_batched((2, 3), False, True, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd26951ec5d8a1eb6cbe38491c64fde2873b9c3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -0,0 +1,128 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f16 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..61aa295b966daf5943e7092572c98ee20143e2b5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py @@ -0,0 +1,146 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f16 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype, + warp_count=None, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1]) +add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) +add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) + +# Tests with different cluster shapes +add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1]) + +# Tests for different schedule modes +add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4], + element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative +) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) +add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8]) +add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8]) +add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8]) +add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8]) +add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8]) + +# Tests with void-C kernels +add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, + cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..bf662b9208ab2a5343d0fd11106835b7d9a5b2e9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py @@ -0,0 +1,104 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F32 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f32 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..3075ddf74bf2a119759ca1a3e47c0815f4b0923c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf36fc77436fef22882e98c752b7a599cf7fb95 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py @@ -0,0 +1,71 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], + element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) + +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..fef6d457a6528a61613d1295877a2b6b8f80fef5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py @@ -0,0 +1,112 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.e4m3 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E4M3Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Test with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +# +# Add a test for E5M2 +# +dtype = cutlass_cppgen.DataType.e5m2 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF8E5M2Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..0a002a5fbad80de5f7b29e42db0806469244914c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with mixed operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype =cutlass_cppgen.DataType.f16 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmMixedSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32) + +# Test with upcast on A +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) + +# Test with upcast on B +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..e226e23684147cb0a9cd5c1270468eb96c67ba15 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0101f78da3b62b599a5deeb89f5596a7e515ce --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffda5b47e37f184c2352f0ee4e737635dbd4147 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py @@ -0,0 +1,423 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from math import prod +import os +import re +import subprocess + +import torch + +from cutlass_library import ( + DataType, + DataTypeSize, + GemmUniversalMode, + LayoutType, + OpcodeClass, + ShortDataTypeNames, + SwizzlingFunctor +) + +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation +from cutlass_cppgen.shape import GemmCoord, MatrixCoord +from cutlass_cppgen.utils.datatypes import torch_type + + +class GemmUniversalLauncher: + def __init__( + self, + operation, + seed=2080, + verification=True, + iterations=500, + compiler_mode= "nvcc", + **kwargs, + ) -> None: + self.math_operation = operation.tile_description.math_instruction.math_operation + self.verification = verification + + if compiler_mode == "nvcc": + compiler.nvcc() + elif compiler_mode == "nvrtc": + compiler.nvrtc() + else: + raise Exception(f"Unexpected compiler string {compiler_mode}") + + op_list = [operation] + if operation.arch < 90: + # Split K via Python is currently only supported for pre-SM90 kernels + self.reduction_operation: ReductionOperation = ReductionOperation( + shape=MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.epilogue_functor.element_epilogue, + epilogue_functor=operation.epilogue_functor, + count=operation.C.alignment, + ) + op_list.append(self.reduction_operation) + + compiler.add_module(op_list, bypass_cache=False) + + self.operation = operation + + self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element) + self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element) + self.dtype_C = torch_type(operation.C.element) + self.dtype_D = torch_type(operation.epilogue_functor.element_output) + + element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) + + if element_size == 1: + self.rand_max = 1 + self.rand_min = 0 + elif element_size <= 8: + self.rand_max = 1 + self.rand_min = -1 + elif element_size == 16: + self.rand_max = 4 + self.rand_min = -4 + else: + self.rand_max = 8 + self.rand_min = -8 + + self.seed = seed + + self.compute_type = operation.epilogue_functor.element_epilogue + self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + + def print_problem_size(self, p, mode, batch_count): + if mode == GemmUniversalMode.Gemm: + mode = "Gemm" + elif mode == GemmUniversalMode.Batched: + mode = "GemmBatched" + elif mode == GemmUniversalMode.GemmSplitKParallel: + mode = "GemmSplitKParallel" + print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}") + + def uniform_init(self, shape, dtype, layout): + size = prod(shape) + if dtype.is_floating_point: + # Initialize data in FP32 and call convert to the data type we desire. + # This is a workaround for the following error that occurs when attempting to + # call uniform_ on a tensor with torch.float8_e4m3fn data: + # RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn' + data = torch.ceil( + torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_( + self.rand_min - 0.5, self.rand_max - 0.5) + ).to(dtype) + else: + # PyTorch does not currently support integer-typed matrix multiplications on GPU. + # Fall back to CPU for integer type references. + data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1) + + is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1) + + if dtype == torch.float64 or dtype == torch.float32 or is_fp8: + data = data.to("cpu") + + data_ref = data.reshape(shape) + + if layout == LayoutType.RowMajor: + data_cutlass = data_ref + else: + data_cutlass = data_ref.transpose(-1, -2).contiguous() + + data_cutlass = data_cutlass.to("cuda") + + # As of this writing, few operations in PyTorch are supported with FP8 data. + # Thus, we perform computation in FP32 for FP8 reference checks. + if is_fp8: + data_ref = data_ref.to(torch.float32) + + return data_cutlass, data_ref + + def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): + # If any tensor is on CPU, place all tensors on CPU unless only + # tensor C is on CPU + # Handle mixed-input cases by casting to the larger data type and overriding + # to whatever the data type of the larger type is + if self.dtype_A != self.dtype_B: + if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]: + tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device) + else: + tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device) + + devices = [x.device.type for x in [tensor_A, tensor_B]] + if tensor_C is not None: + devices.append(tensor_C.device.type) + + if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]: + device = torch.device("cpu") + else: + device = tensor_A.device + + tensor_A = tensor_A.to(device) + tensor_B = tensor_B.to(device) + if tensor_C is not None: + tensor_C = tensor_C.to(device) + + dtype = torch_type(self.compute_type) + alpha_torch = torch.tensor([alpha], device=device).to(dtype) + beta_torch = torch.tensor([beta], device=device).to(dtype) + + tmp = tensor_A @ tensor_B + tensor_D_ref = (alpha_torch * tmp) + if tensor_C is not None: + tensor_D_ref += (tensor_C * beta_torch) + return tensor_D_ref.to(self.dtype_D) + + def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0): + torch.random.manual_seed(self.seed) + + # Assign an actual batch count in cases where we are not running in batched mode. + # This is to differentiate between the number of split K slices and the batch count, + # which are overloaded within the single `batch_count` variable. + if mode == GemmUniversalMode.Batched: + true_batch_count = batch_count + else: + true_batch_count = 1 + + def transpose(layout): + if layout == LayoutType.RowMajor: + return LayoutType.ColumnMajor + else: + return LayoutType.RowMajor + + tensor_A, tensor_A_ref = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.k), + self.dtype_A, + self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout), + ) + tensor_B, tensor_B_ref = self.uniform_init( + (true_batch_count, problem_size.k, problem_size.n), + self.dtype_B, + self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout), + ) + if self.dtype_C is not None: + tensor_C, tensor_C_ref = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_C, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + else: + tensor_C = None + tensor_C_ref = None + + tensor_D, _ = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_D, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + tensor_D = torch.zeros_like(tensor_D) + + if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: + alpha = int(alpha) + beta = int(beta) + + # + # Launch kernel + # + + arguments = GemmArguments( + operation=self.operation, + problem_size=problem_size, + A=tensor_A, + B=tensor_B, + C=tensor_C, + D=tensor_D, + output_op=self.operation.epilogue_type(alpha, beta), + gemm_mode=mode, + split_k_slices=split_k_slices, + batch=batch_count, + ) + + if mode == GemmUniversalMode.GemmSplitKParallel: + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[problem_size.m, problem_size.n], + partitions=split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op=self.reduction_operation.epilogue_type(alpha, beta), + ) + + self.operation.run(arguments) + + if mode == GemmUniversalMode.GemmSplitKParallel: + self.reduction_operation.run(reduction_arguments) + + passed = True + + if self.verification: + if mode == GemmUniversalMode.GemmSplitKParallel: + reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() + else: + arguments.sync() + tensor_D_ref = self.reference( + problem_size, + tensor_A_ref, + tensor_B_ref, + tensor_C_ref, + alpha, + beta, + ) + + tensor_D_ref = tensor_D_ref.to('cuda') + + if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor: + tensor_D = tensor_D.transpose(-1, -2).contiguous() + + passed = tensor_D.equal(tensor_D_ref) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size, mode, batch_count) + del arguments + if mode == GemmUniversalMode.GemmSplitKParallel: + del reduction_arguments + + return passed + + +def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"): + passed = True + + minimum_operand_element_size = min( + DataTypeSize[operation.A.element], DataTypeSize[operation.B.element] + ) + opcode_class = operation.tile_description.math_instruction.opcode_class + + if opcode_class == OpcodeClass.Simt: + alignment = 1 + else: + alignment = 128 // minimum_operand_element_size + + alignment_m = alignment + alignment_n = alignment + alignment_k = alignment + + # INT8 alignment constraints + if opcode_class == OpcodeClass.Simt: + A_is_s8 = operation.A.element == DataType.s8 + B_is_s8 = operation.B.element == DataType.s8 + + if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor: + alignment_m = 4 + if B_is_s8 == DataType.s8 and operation.A.layout == LayoutType.RowMajor: + alignment_n = 4 + if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor): + alignment_k = 4 + + threadblock_k = operation.tile_description.threadblock_shape[2] + + assert testcase != "interleaved" + + supports_split_k = operation.arch < 90 and not operation.swizzling_functor == SwizzlingFunctor.StreamK + + if testcase == "multistage": + modes = [GemmUniversalMode.Gemm] + problem_size_m = [16, 528] + problem_size_n = [16, 528] + problem_size_k = [ + threadblock_k, + threadblock_k * operation.tile_description.stages + + operation.tile_description.math_instruction.instruction_shape[2], + ] + problem_alpha = [1.0] + problem_beta = [0.0] + batch_counts = [1] + else: + modes = [GemmUniversalMode.Gemm] + batch_counts = [1, 2, 3, 5, 7] + if supports_split_k: + modes.append(GemmUniversalMode.GemmSplitKParallel) + + problem_size_m = [alignment_m, 512 - 3 * alignment_m] + problem_size_n = [alignment_n, 512 - 2 * alignment_n] + if operation.tile_description.stages is None: + stages_for_k_calc = 7 + else: + stages_for_k_calc = operation.tile_description.stages + problem_size_k = [ + alignment_k, + threadblock_k * stages_for_k_calc - alignment_k, + threadblock_k * stages_for_k_calc * 3 - alignment_k, + ] + problem_alpha = [1.0] + problem_beta = [2.0] + + testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode) + + for mode in modes: + for m in problem_size_m: + for n in problem_size_n: + for k in problem_size_k: + for batch_count in batch_counts: + for alpha in problem_alpha: + for beta in problem_beta: + # skip very small K problems + if testcase == "universal": + if k // batch_count < 2 * threadblock_k: + continue + + problem_size = GemmCoord(m, n, k) + + if supports_split_k: + split_k_slices = batch_count + else: + split_k_slices = 1 + + overridden_mode = mode + if mode == GemmUniversalMode.Gemm and batch_count > 1: + overridden_mode = GemmUniversalMode.Batched + + passed = testbed.run( + overridden_mode, + problem_size, + batch_count, + split_k_slices, + alpha, + beta, + ) + + if not passed: + return False + + return passed diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5e7467b1e0040ce3012ff8541dfbac381bb861 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'gemm_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28bba3e922961c96df75f8685e3064ab55cbbc87 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py @@ -0,0 +1,260 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_library import SubstituteTemplate + +import cutlass_cppgen +from cutlass_library import ( + DataTypeNames, + EpilogueScheduleSuffixes, + KernelScheduleSuffixes, + LayoutType, + OpcodeClassNames, + ShortDataTypeNames, + ShortLayoutTypeNames +) +from cutlass_cppgen.backend import library + +from gemm_testbed import test_all_gemm + + +class Layout: + """ + Utility class to map transpose and non-transpose terminology to row- and column-major terminology + """ + + T = LayoutType.RowMajor + N = LayoutType.ColumnMajor + + +class LayoutCombination: + """ + Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs + """ + + NNN = (Layout.N, Layout.N, Layout.N) + NNT = (Layout.N, Layout.N, Layout.T) + NTN = (Layout.N, Layout.T, Layout.N) + NTT = (Layout.N, Layout.T, Layout.T) + TNN = (Layout.T, Layout.N, Layout.N) + TNT = (Layout.T, Layout.N, Layout.T) + TTN = (Layout.T, Layout.T, Layout.N) + TTT = (Layout.T, Layout.T, Layout.T) + + +def get_name( + layouts, + alignments, + element_output, + element_accumulator, + element_epilogue, + cluster_shape, + threadblock_shape, + stages, + element_a, + element_b, + element_c, + arch, + opclass, + kernel_schedule=None, + epilogue_schedule=None, + suffix="", +): + """ + Generates a procedural name for a test case. + + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alignments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param arch: compute capability of kernel being generated + :type arch: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param kernel_schedule: kernel_schedule type + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue_schedule type + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param suffix: additional string to add to the suffix of the name + :type suffix: str + + :return: str + """ + name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "eA": DataTypeNames[element_a], + "eB": DataTypeNames[element_b], + "eC": DataTypeNames[element_c], + "lA": ShortLayoutTypeNames[layouts[0]], + "lB": ShortLayoutTypeNames[layouts[1]], + "lC": ShortLayoutTypeNames[layouts[2]], + "opclass": OpcodeClassNames[opclass], + "acc": DataTypeNames[element_accumulator], + "cM": str(cluster_shape[0]), + "cN": str(cluster_shape[1]), + "cK": str(cluster_shape[2]), + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "stages": str(stages) if stages is not None else "auto", + "aA": str(alignments[0]), + "aB": str(alignments[1]), + "aC": str(alignments[2]), + "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule], + "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule], + "suffix": "" if suffix is None else suffix, + }, + ) + + +def add_test_gemm( + cls=None, + cc=None, + element=None, + layouts=None, + alignments=None, + element_output=None, + element_accumulator=None, + cluster_shape=None, + threadblock_shape=None, + warp_count=None, + stages=None, + opclass=None, + swizzle=None, + kernel_schedule=None, + epilogue_schedule=None, + compilation_modes=['nvcc', 'nvrtc'], + element_A=None, + element_B=None, + element_C=None): + """ + Create test-running functions with the given specification and set it as a method of ``cls``. + + :param cls: class to which the generated method will be added + :type cls: type + :param cc: compute capability to compile for + :type cc: int + :param element: data type of A and B operands + :type element: cutlass_cppgen.DataType.f16 + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass_cppgen.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass_cppgen.DataType + :param cluster_shape: dimensions of clusters + :type cluster_shape: list or tuple + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param swizzle: threadblock swizzling functor + :param kernel_schedule: kernel schedule to use + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue schedule to use + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc') + :type compilation_modes: list, + :param element_A: data type of operand A. If set, overrides ``element`` + :type element_A: cutlass_cppgen.DataType + :param element_B: data type of operand B. If set, overrides ``element`` + :type element_B: cutlass_cppgen.DataType + :param element_C: data type of operand C. If set, overrides ``element`` + :type element_C: cutlass_cppgen.DataType + """ + + if element_A is None: + element_A = element + if element_B is None: + element_B = element + if element_C is None: + element_C = element + if element_output is None: + element_output = element + if element_accumulator is None: + element_accumulator = element + + for compilation_mode in compilation_modes: + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + + td = plan.tile_descriptions()[0] + + if warp_count is not None: + td.warp_count = warp_count + td.threadblock_shape = threadblock_shape + td.stages = stages + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode)) + + element_epilogue = element_accumulator + name = get_name( + layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator, + element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape, + stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass, + kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}') + + setattr(cls, name, run) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/installation.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/installation.py new file mode 100644 index 0000000000000000000000000000000000000000..f550c394812c7fede55070e4c99c4471a69c2f88 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/installation.py @@ -0,0 +1,57 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests for a successful installation of the CUTLASS Python interface +""" + +import os +import unittest + +import cutlass_cppgen +import cutlass_library + + +class InstallationTest(unittest.TestCase): + def test_cutlass_source_paths(self): + """ + Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages + """ + src_file = 'include/cutlass/cutlass.h' + library_file = os.path.join(cutlass_library.source_path, src_file) + cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file) + assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." + assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5d46d45d617198a46bec85cd7218cb5431a7b1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py @@ -0,0 +1,284 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level Conv2d interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException +import os + + +class Conv2dEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Conv2d interface + """ + def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, + alignment_A, alignment_B, alignment_C): + + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + + self.conv_kind = conv_kind + + self.plan = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator) + + self.op = self.plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default Conv2d + :type other_plan: cutlass_cppgen.op.Conv2d + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types + and layouts for constructing the Conv2d interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors using generic element and output + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal + if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test with only the generic types. Only rune if the types of A, B, C, and D are the same + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator): + plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + size = (2, 2) + A = np.zeros(size, dtype=type_A) + B = np.zeros(size, dtype=type_B) + C = np.zeros(size, dtype=type_C) + D = np.zeros(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def torch_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend + """ + if not datatypes.is_torch_available(): + return + + import torch + type_A = datatypes.torch_type(self.element_A) + type_B = datatypes.torch_type(self.element_B) + type_C = datatypes.torch_type(self.element_C) + type_D = datatypes.torch_type(self.element_D) + type_accum = datatypes.torch_type(self.element_accumulator) + + size = (2, 2) + + A = torch.empty(size, dtype=type_A) + B = torch.empty(size, dtype=type_B) + C = torch.empty(size, dtype=type_C) + D = torch.empty(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + if type_A == type_B: + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): + plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + self.torch_test() + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class ConvEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Conv2d interface + """ + pass + +type2alignment = { + cutlass_cppgen.DataType.f16: 8, + cutlass_cppgen.DataType.f32: 4 +} + +def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): + + test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}" + + def run(self): + conv2d_eq = Conv2dEquivalence( + conv_kind=conv_kind, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + element_accumulator=element_accumulator, + alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B], + alignment_C=type2alignment[element_C] + ) + conv2d_eq.test_all() + + setattr(ConvEquivalenceTest, test_name, run) + +for conv_kind in ["fprop", "wgrad", "dgrad"]: + for types in [ + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16], + [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32], + [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32] + ]: + add_test(conv_kind, types[0], types[1], types[2], types[3], types[4]) + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class Conv2dErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) + + with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16) + + td = plan.tile_descriptions()[0] + td.threadblock_shape=[17, 32, 5] + + plan.tile_description = td + with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'): + plan.compile() + # Clean up the error message + os.remove("./cutlass_python_compilation_device_error.txt") + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d67f4d07f01b0936ff5796bfb6fe4c98b5c031 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py @@ -0,0 +1,254 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Test the EVT interface +""" + +import numpy as np +import unittest + +import cutlass_cppgen +from cutlass_cppgen import LayoutType, Tensor +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import reshape, permute + +from utils import ExpectException + + +@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") +class EVTErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the EVT interface + """ + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'") + def test_root_not_d(self): + """ + Test when "D" does not exist in Sm90 EVT + """ + def evt_root_not_d(accum, alpha): + F = accum * alpha + return F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(device_cc() == 90, + "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, " + "but the variable 'D' is not found in the return values.", True): + + cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors) + + def test_no_accum(self): + """ + Test when "accum" is not in input arguments + """ + def evt_no_accum(alpha, C): + D = alpha * C + return D + + example_tensors = { + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True): + cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors) + + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size") + def test_too_much_shared_memory(self): + """ + Test when the epilogue consumes too much shared memory + """ + def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8): + D1 = accum + C1 + D2 = D1 + C2 + D3 = D2 + C3 + D4 = D3 + C4 + D5 = D4 + C5 + D6 = D5 + C6 + D7 = D6 + C7 + D = D7 + C8 + return D, D1, D2, D3, D4, D5, D6, D7 + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C1": self.fake_tensor(np.float16, (6, 512, 512)), + "C2": self.fake_tensor(np.float16, (6, 512, 512)), + "C3": self.fake_tensor(np.float16, (6, 512, 512)), + "C4": self.fake_tensor(np.float16, (6, 512, 512)), + "C5": self.fake_tensor(np.float16, (6, 512, 512)), + "C6": self.fake_tensor(np.float16, (6, 512, 512)), + "C7": self.fake_tensor(np.float16, (6, 512, 512)), + "C8": self.fake_tensor(np.float16, (6, 512, 512)), + "D1": self.fake_tensor(np.float16, (6, 512, 512)), + "D2": self.fake_tensor(np.float16, (6, 512, 512)), + "D3": self.fake_tensor(np.float16, (6, 512, 512)), + "D4": self.fake_tensor(np.float16, (6, 512, 512)), + "D5": self.fake_tensor(np.float16, (6, 512, 512)), + "D6": self.fake_tensor(np.float16, (6, 512, 512)), + "D7": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors) + + plan = cutlass_cppgen.op.Gemm( + element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor, + element_accumulator=np.float32 + ) + + with ExpectException(True, + "RuntimeError: The epilogue consumes too much shared memory. " + "No valid tile description is found in the generator.", True): + plan.epilogue_visitor = epilogue_visitor + + def test_not_ssa(self): + """ + Test when the epilogue is not in SSA + """ + def evt_redefine(accum, C, alpha): + F = accum + C + F = F * alpha + D = F + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True): + cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors) + + def evt_undefine(accum, alpha): + F = accum + C + D = F * alpha + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True): + cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors) + + def test_missing_example_tensor(self): + """ + Test when the example tensor of an input/output variable is not provided + """ + def evt_missing_example_tensor(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for D is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for C is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + def test_return_expression(self): + """ + Test when the return value is an expression + """ + def evt_return_expr(accum, C): + return accum + C + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "SyntaxError: Return value cannot be an expression", True): + cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors) + + def test_incompatible_shape(self): + """ + Test when the shape of example tensors are incompatible + """ + def evt_incompatible_shape(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 256, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, + "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True): + cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors) + + def test_no_matching_impl(self): + def evt_no_matching_impl(accum, bias): + D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1)) + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 256)), + "bias": self.fake_tensor(np.float16, (16, 32)), + "D": self.fake_tensor(np.float16, (6, 512, 256)) + } + + with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True): + cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors) + # + # Helper functions + # + + def fake_tensor(self, element, shape): + return Tensor(element=element, shape=shape, layout_tag=LayoutType.RowMajor) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2913d5933f5342cc58b4f252657a724d2c7692da --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py @@ -0,0 +1,354 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level GEMM interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException + + +class GemmEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Gemm interface + """ + def __init__(self, element_A, element_B, element_C, element_D, element_accumulator, + layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C): + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) + self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default GEMM + :type other_plan: cutlass_cppgen.op.Gemm + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) + + # Compare whether the operations are equal by comparing the C++ code that would be emitted for them + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types + and layouts for constructing the Gemm interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_B=self.layout_B, layout_C=self.layout_C, + element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if self.element_A == self.element_B and self.layout_A == self.layout_B: + plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, + layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + layout_to_order = { + cutlass_cppgen.LayoutType.RowMajor: 'C', + cutlass_cppgen.LayoutType.ColumnMajor: 'F' + } + size = (2, 2) + A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) + B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B) + C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C) + D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) + + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if type_A == type_B and self.layout_A == self.layout_B: + plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + + +class GemmEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Gemm interface + """ + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, + layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") + def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64, + element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=1, alignment_B=1, alignment_C=1) + gemm_eq.test_all() + + +class GemmErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + + with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) + + def test_tensorop_availability(self): + """ + Tests case in which only SIMT operations are available but TensorOp is requested + """ + cc = device_cc() + + # F64 Tensor Core operations are only avaiable on certain devices + supports_tensorop_f64 = cc in [80, 89, 90] + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) + + error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' + with ExpectException(not supports_tensorop_f64, error_msg): + plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp + + expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt + assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") + def test_opclass_switch(self): + """ + Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) + """ + plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp + + # Ensure that all tile descriptions have opclass of TensorOp + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp + + plan.opclass = cutlass_cppgen.OpcodeClass.Simt + + # Ensure that all tile descriptions have opclass of Simt + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + cc = device_cc() + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + td = plan.tile_descriptions()[0] + stages = td.stages + + # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage + # count should be used + with ExpectException(cc < 90, f'Requested zero stages'): + td.stages = 0 + plan.construct(td) + + if cc < 90: + with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): + td.stages = 3 + plan.construct(td) + elif cc == 90: + original_kschedule = td.kernel_schedule + original_eschedule = td.epilogue_schedule + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized + td.stages = 3 + plan.construct(td) + # Reset schedules + td.kernel_schedule = original_kschedule + td.epilogue_schedule = original_eschedule + elif cc in [100, 101, 103]: + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.stages = 3 + plan.construct(td) + + with ExpectException(True, f'Requested too many stages'): + td.stages = 100 + plan.construct(td) + + # Reset stage count + td.stages = stages + + cluster_shape = td.cluster_shape + with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'): + td.cluster_shape = [2, 1, 1] + plan.construct(td) + + # Reset cluster shape + td.cluster_shape = cluster_shape + + with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto + plan.construct(td) + + with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative + td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK + plan.construct(td) + + # Ensure that all returned tile descriptions are unique + ops = {} + for i, td in enumerate(plan.tile_descriptions()): + op = plan.construct(td) + code_str = op.rt_module.emit() + if code_str in ops: + conflicting_td = ops[code_str] + assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}' + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f93ca26e2d79a15dab4dd0045836ebd9fe62757 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py @@ -0,0 +1,69 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Helper functions & classes for interface test +""" +class ExpectException: + """ + Utility class to assert that an exception was raised when expected + + Example: + + .. highlight:: python + .. code-block:: python + + with ExceptionExpected(True, 'Division by zero'): + x = 1.0 / 0.0 + + :param exception_expected: whether an exception is expected to be raised + :type exception_expected: bool + :param message: message to print if an exception is raised when not expected or vice versa + :type message: str + """ + def __init__(self, exception_expected: bool, message: str = '', verify_msg=False): + self.exception_expected = exception_expected + self.message = message + self.verify_msg = verify_msg + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + exception_raised = exc_type is not None + assert self.exception_expected == exception_raised, self.message + if self.verify_msg: + exc_message = f"{exc_type.__name__}: {exc_val}" + assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}" + + # Suppress the exception + return True diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cdc421ccffffeb7bd1696aaf9916330a6625ca --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility script for discovering and running all PyCuTe tests +""" + +import argparse +import logging +import pathlib +import unittest + + +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level into the numeric identifier used + in setting the log level + + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {log_level}") + return numeric_level + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, "test_*.py") + test_runner = unittest.runner.TextTestRunner() + results = test_runner.run(tests) + if not results.wasSuccessful(): + raise Exception("Test cases failed") diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py new file mode 100644 index 0000000000000000000000000000000000000000..d4330377cab7079ea16422f194ddf4f2403ea507 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py @@ -0,0 +1,95 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.coalesce +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestCoalesce(unittest.TestCase): + def helper_test_coalesce(self, layout): + layoutR = coalesce(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + self.assertEqual(size(layoutR), size(layout)) + + for i in range(size(layout)): + self.assertEqual(layoutR(i), layout(i)) + + def test_coalesce(self): + layout = Layout(1,0) + self.helper_test_coalesce(layout) + + layout = Layout(1,1) + self.helper_test_coalesce(layout) + + layout = Layout((2,4)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (1,6,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (1,7,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (4,7,8)) + self.helper_test_coalesce(layout) + + layout = Layout((2,(4,6))) + self.helper_test_coalesce(layout) + + layout = Layout((2,4), (4,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (24,6,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,3), (2,4,4)) + self.helper_test_coalesce(layout) + + layout = Layout(((2,2),(2,2)), ((1,4),(8,32))) + self.helper_test_coalesce(layout) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8684a55b19c90eae11ddd1cca011c2ff8270b5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.complement +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComplement(unittest.TestCase): + def helper_test_complement(self, layout): + layoutR = complement(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + # Post-condition: test disjointness of the codomains + for a in range(size(layout)): + for b in range(size(layoutR)): + assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0) + + def test_complement(self): + test = Layout(1,0) + self.helper_test_complement(test) + + test = Layout(1,1) + self.helper_test_complement(test) + + test = Layout(4,0) + self.helper_test_complement(test) + + test = Layout((2,4),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,3),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,4),(1,4)) + self.helper_test_complement(test) + + test = Layout((2,4,8),(8,1,64)) + self.helper_test_complement(test) + + test = Layout(((2,2),(2,2)),((1,4),(8,32))) + self.helper_test_complement(test) + + test = Layout((2,(3,4)),(3,(1,6))) + self.helper_test_complement(test) + + test = Layout((4,6),(1,6)) + self.helper_test_complement(test) + + test = Layout((4,10),(1,10)) + self.helper_test_complement(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py new file mode 100644 index 0000000000000000000000000000000000000000..6c27eb7fe6cbb7bbbea7bd644ac8e64a2fc853c9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py @@ -0,0 +1,213 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.composition +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComposition(unittest.TestCase): + def helper_test_composition(self, layoutA, layoutB): + layoutR = composition(layoutA, layoutB) + + _LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}") + + # True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. + + # Test that R(c) = A(B(c)) for all coordinates c in layoutR + for i in range(size(layoutR)): + self.assertEqual(layoutR(i), layoutA(layoutB(i))) + + def test_composition(self): + layoutA = Layout(1,0) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,0) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((1), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((1), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((2,3), (2,4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8), (8,1)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + layoutB = Layout(8, 4) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((4,2)), ((1,16))) + layoutB = Layout((4,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((2,2), (2,1)) + layoutB = Layout((2,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2)) + layoutB = Layout((2,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((2,2,2), (1,8,2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((4,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + # Pre-coalesced LHS + layoutA = Layout((4,6,8),(1,4,7)) + layoutB = Layout((6),(1)) + self.helper_test_composition(layoutA, layoutB) + + # Mid-layout truncation + layoutA = Layout((4,6,8,10),(2,3,5,7)) + layoutB = Layout(6,12) + self.helper_test_composition(layoutA, layoutB) + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbf443c9725735b0051d0a225a55eece9c663a8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py @@ -0,0 +1,80 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.int_tuple +""" + +import unittest + +from pycute import * + + +class TestIntTuple(unittest.TestCase): + def test_product(self): + self.assertEqual(product(2), 2) + + self.assertEqual(product((3,2)), 6) + + self.assertEqual(product(product(((2,3),4))), 24) + + def test_inner_product(self): + self.assertEqual(inner_product(2, 3), 6) + + self.assertEqual(inner_product((1,2), (3,2)), 7) + + self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15) + + def test_shape_div(self): + self.assertEqual(shape_div((3,4), 6), (1,2)) + + self.assertEqual(shape_div((3,4), 12), (1,1)) + + self.assertEqual(shape_div((3,4), 36), (1,1)) + + self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2)) + + self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2))) + + def test_prefix_product(self): + self.assertEqual(prefix_product(2), 1) + + self.assertEqual(prefix_product((3,2)), (1,3)) + + self.assertEqual(prefix_product((3,2,4)), (1,3,6)) + + self.assertEqual(prefix_product(((2,3),4)), ((1,2),6)) + + self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))), + ((1,2),(6,12,12),(24,120,240))) + + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..a6501fd6c7c6fc5a518e4d22bf93dc0e4746a8ba --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py @@ -0,0 +1,87 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestLeftInverse(unittest.TestCase): + def helper_test_left_inverse(self, layout): + inv_layout = left_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(layout)): + self.assertEqual(inv_layout(layout(i)), i) + + def test_left_inverse(self): + test = Layout(1,0) + self.helper_test_left_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_left_inverse(test) + + test = Layout(1,1) + self.helper_test_left_inverse(test) + + test = Layout(4,1) + self.helper_test_left_inverse(test) + + test = Layout(4,2) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_left_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_left_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed9759d7808da8087fe9c76761d2dd9eaeab08b --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py @@ -0,0 +1,96 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestRightInverse(unittest.TestCase): + def helper_test_right_inverse(self, layout): + inv_layout = right_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(inv_layout)): + self.assertEqual(layout(inv_layout(i)), i) + + def test_right_inverse(self): + test = Layout(1,0) + self.helper_test_right_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout((3,7),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout(1,1) + self.helper_test_right_inverse(test) + + test = Layout(4,0) + self.helper_test_right_inverse(test) + + test = Layout(4,1) + self.helper_test_right_inverse(test) + + test = Layout(4,2) + self.helper_test_right_inverse(test) + + test = Layout((2,4),(0,2)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_right_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_right_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb99a4833529e18fa22d65a235ce80dad372365 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py @@ -0,0 +1,59 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.typing +""" + +import logging +import unittest +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestTyping(unittest.TestCase): + def helper_test_typing(self, _cls, _obj, cls, expected: bool): + _LOGGER.debug(f"issubclass({_cls}, {cls})") + _LOGGER.debug(f"isinstance({_obj}, {cls})") + + self.assertEqual(expected, issubclass(_cls, cls)) + self.assertEqual(expected, isinstance(_obj, cls)) + + def test_typing(self): + self.helper_test_typing(int, 1, Integer, True) + self.helper_test_typing(float, 1., Integer, False) + self.helper_test_typing(str, 'hi', Integer, False) + self.helper_test_typing(bool, False, Integer, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h new file mode 100644 index 0000000000000000000000000000000000000000..86b7823785a9f2a957cf505740d6cfde45ccfef1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */ + +#pragma nv_diag_suppress boolean_controlling_expr_is_constant +#include +#pragma nv_diag_warning boolean_controlling_expr_is_constant +#pragma warning( disable : 4503) + +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gets a CUDA device +cudaDeviceProp GetCudaDevice(); + +/// Prints device properties +std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Sets flags for Unit test +void FilterArchitecture(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order +// of problem sizes run by CUTLASS unit tests +int CutlassUnitTestProblemCount(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// active test macro +#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ + +// disabled test macro +#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} + +#if CUTLASS_TEST_LEVEL == 0 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#elif CUTLASS_TEST_LEVEL == 1 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#else +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#endif + +#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) +#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUDA_12_0_SM90_FEATURES_SUPPORTED true +#else + #define CUDA_12_0_SM90_FEATURES_SUPPORTED false +#endif + +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h new file mode 100644 index 0000000000000000000000000000000000000000..3035e9862bcb79b749b4cbc4a74341bceac9c598 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h @@ -0,0 +1,907 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Helper to construct cached name for +*/ +#pragma once + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "thrust/universal_vector.h" + +#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS +#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result of a test +struct CachedTestKey { + + std::string op; ///< Concatenated string representation of operation performed + std::string problem; ///< Concatenated string representation of problem description + std::string types; ///< Concatenated string representation of operand types + uint32_t A; ///< Hashed result of tensor A + uint32_t B; ///< Hashed result of tensor B + uint32_t C; ///< Hashed result of tensor C + + // + // Methods + // + inline CachedTestKey(): A(), B(), C() { } + + inline CachedTestKey( + std::string op, ///< Concatenated string representation of operation performed + std::string problem, ///< Concatenated string representation of problem description + std::string types, ///< Concatenated string representation of operand types + uint32_t A, ///< Hashed result of tensor A + uint32_t B, ///< Hashed result of tensor B + uint32_t C ///< Hashed result of tensor C + ): + op(op), problem(problem), types(types), A(A), B(B), C(C) + { } + + /// Checks for equality of the problem + bool operator==(CachedTestKey const &rhs) const { + return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { + + in >> result.op; + in >> result.problem; + in >> result.types; + in >> result.A; + in >> result.B; + in >> result.C; + + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { + + out << result.op << " "; + out << result.problem << " "; + out << result.types << " "; + out << result.A << " "; + out << result.B << " "; + out << result.C << " "; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResult { + uint32_t D; + // + // Methods + // + + CachedTestResult(): D() + { } + + CachedTestResult(uint32_t D): D(D) + { } + + operator bool() const { + return bool(D); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { + in >> result.D; + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { + out << result.D; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResultListing { + + std::list> results; + + // + // Methods + // + + inline CachedTestResultListing(std::string const &path) { + std::ifstream file(path); + + while (file.good()) { + CachedTestKey key; + file >> key; + + CachedTestResult result; + file >> result; + + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + } + + /// Returns the cached result + std::pair find(CachedTestKey const &rhs) const { + for (auto const & result : results) { + if (result.first == rhs) { + return std::make_pair(true, result.second); + } + } + return std::make_pair(false, CachedTestResult()); + } + + /// Appends an entry + void append(CachedTestKey const &key, CachedTestResult const &result) { + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + + /// Writes the entire listing to a file + bool write(std::string const &path) { + std::ofstream file(path); + if (!file.good()) { + return false; + } + + for (auto const &result : results) { + file << result.first << result.second << std::endl; + } + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScalarEncoder { + Element scalar; + + ScalarEncoder(Element s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + Element s = scalar; + if (s < Element()) { + s = -s; + ss << "n"; + } + ss << s; + return ss.str(); + } +}; + +template +ScalarEncoder EncodeScalar(Element a) { + return ScalarEncoder(a); +} + +template +struct ScalarEncoder> { + cutlass::complex scalar; + + ScalarEncoder(cutlass::complex s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; + return ss.str(); + } +}; + +template +std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { + out << scalar.str(); + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { + switch (conv_op) { + case cutlass::conv::Operator::kFprop: return "fprop"; + case cutlass::conv::Operator::kDgrad: return "dgrad"; + case cutlass::conv::Operator::kWgrad: return "wgrad"; + case cutlass::conv::Operator::kDeconv: return "deconv"; + } + return "conv_unknown"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode GemmCoord (Gemm problem size) +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::gemm::GemmCoord const &problem) { + + out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode Conv2dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv2dProblemSize const &problem) { + + out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode Conv3dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv3dProblemSize const &problem) { + + out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode 3.x ConvNd ProblemShape +template +inline std::ostream &EncodeProblemSize( + std::ostream &out, + ProblemShape const& problem_shape) { + + out << problem_shape.shape_A << "_"; + out << problem_shape.shape_B << "_"; + + out << "padl" << problem_shape.lower_padding << "_"; + out << "padu" << problem_shape.upper_padding << "_"; + out << "str" << problem_shape.traversal_stride << "_"; + out << "dil" << problem_shape.dilation << "_"; + + switch (problem_shape.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string ElementTypeName() { + return std::string(typeid(Element).name()); +} + +template <> +inline std::string ElementTypeName() { + return "h"; +} + +template <> +inline std::string ElementTypeName>() { + return "ch"; +} + +template <> +inline std::string ElementTypeName() { + return "bf16"; +} + +template <> +inline std::string ElementTypeName>() { + return "cbf16"; +} + +template <> +inline std::string ElementTypeName() { + return "tf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "ctf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "c"; +} + +template <> +inline std::string ElementTypeName>() { + return "z"; +} + +template <> +inline std::string ElementTypeName>() { + return "q"; +} + +template <> +inline std::string ElementTypeName() { + return "s8"; +} + +template <> +inline std::string ElementTypeName() { + return "u8"; +} + +template <> +inline std::string ElementTypeName() { + return "s4"; +} + +template <> +inline std::string ElementTypeName() { + return "u4"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string LayoutTypeName() { + return std::string(typeid(Layout).name()); +} + +template <> +inline std::string LayoutTypeName() { + return "n"; +} + +template <> +inline std::string LayoutTypeName() { + return "t"; +} + +template <> +inline std::string LayoutTypeName() { + return "nhwc"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc32hw32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc64hw64"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c32rsk32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c64rsk64"; +} + +template <> +inline std::string LayoutTypeName() { + return "ndhwc"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName() << LayoutTypeName(); + return ss.str(); +} + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName(); + return ss.str(); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function on a byte array +struct CRC32 { + + uint32_t table[256]; + + // + // Methods + // + + CRC32() { + + uint32_t rem; + int i, j; + + for (i = 0; i < 256; i++) { + rem = i; + for (j = 0; j < 8; j++) { + if (rem & 1) { + rem >>= 1; + rem ^= 0xedb88320; + } else + rem >>= 1; + } + table[i] = rem; + } + } + + /// Computes the CRC of an array of bytes + uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { + uint8_t const *p = static_cast(start); + uint8_t const *q = static_cast(start) + length; + + crc = ~crc; + + for (; p != q; ++p) { + uint8_t octet = *p; + crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; + } + + return ~crc; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Element, typename Layout +> +uint32_t TensorHash( + cutlass::TensorView view, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); +} + +template +uint32_t TensorHash( + thrust::universal_vector& tensor, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits::value / 8, crc); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName() << "_" + << ElementTypeName(); + + return out; +} + +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName(); + + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedGemmTestKey( + cutlass::gemm::GemmCoord const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode gemm operator and problem sizes + key.op = "gemm"; + + std::stringstream ss_problem; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_broadcast"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithReductionTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_reduction"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv3dTestKey( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv3dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv3d operator and problem sizes + key.op = "conv3d"; + + std::stringstream ss_problem; + + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline CachedTestKey CreateCachedConvNd3xTestKey( + cutlass::conv::Operator conv_operator, + ProblemShape const& problem_shape, + double alpha, + double beta, + thrust::universal_vector A, + thrust::universal_vector B, + thrust::universal_vector C +) { + + CachedTestKey key; + + // Encode convNd operator and problem sizes + std::stringstream ss_op; + ss_op << "conv" << ProblemShape::RankS << "d"; + key.op = ss_op.str(); + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem_shape); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, + ElementB, + ElementC, + ElementD>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..a14134b2854732e669977831207a456d28beed9f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h @@ -0,0 +1,927 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed sizes for Conv2d problem +*/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +namespace test { +namespace conv { +namespace device { + +using Conv2dProblemVector = std::vector; + +// +// Structures to prune items from Conv2dProblemVector +// +// Specification template for pruning items for convolution problem lists +template struct Specification +{ + virtual ~Specification() = default; + virtual bool is_satisfied(T item) const = 0; +}; + +// input size (NHWC) specification +struct InputSizeSpecification : Specification +{ + cutlass::Tensor4DCoord input_size; + + InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {} + + bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { + return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C)); + } +}; + +// stride (stride_h, stride_w) specification +struct StrideSpecification : Specification +{ + cutlass::MatrixCoord stride; + + StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {} + + bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { + return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h)); + } +}; + +// channel (C,K) specification, must be multiple of minimum channel +struct ChannelDivisibilitySpecification : Specification +{ + int channel_multiple; + + ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {} + + bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override { + return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0)); + } +}; + +// +// Pruning function for items from Conv2dProblemVector based on a Specification +// +inline Conv2dProblemVector prune(Conv2dProblemVector const &items, + Specification const &spec) +{ + Conv2dProblemVector pruned_list; + + for (auto& p : items) + if (spec.is_satisfied(p)) + pruned_list.push_back(p); + return pruned_list; +} + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv2dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv2dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + + Conv2dProblemVector conv2d_default_sizes; + Conv2dProblemVector conv2d_rigorous_sizes; + Conv2dProblemVector conv2d_resnet50_sizes; + Conv2dProblemVector conv2d_resnet50_sizes_perf; + + // + // Methods + // + /// Default ctor + TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + initialize_conv2d_default_sizes(); + initialize_conv2d_rigorous_sizes(); + initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/); + + initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &conv2d_default_sizes, + &conv2d_rigorous_sizes, + &conv2d_resnet50_sizes, + &conv2d_resnet50_sizes_perf + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (2,2) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 11, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 17, 19, minimum_channel_size}, // input size (NHWC) + {16, 2, 2, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 5, minimum_channel_size}, // input size (NHWC) + {16, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 17, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 20, 24, 8}, // input size (NHWC) + {40, 3, 3, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 160}, // input size (NHWC) + {224, 1, 1, 160}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 19, 37, 160}, // input size (NHWC) + {224, 3, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 160}, // input size (NHWC) + {224, 2, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 128}, // input size (NHWC) + {224, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 29, 37, 160}, // input size (NHWC) + {224, 5, 5, 160}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 16, 288}, // input size (NHWC) + {160, 5, 5, 288}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 55, 51, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 71, 80, 32}, // input size (NHWC) + {64, 5, 5, 32}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 224, 224, 8}, // input size (NHWC) + {64, 7, 7, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size stride (3, 3), filter (3, 3), non-default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 23, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size padding > stride, asymmetric filter, padding and striding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 31, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 4}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 35, 256}, // input size (NHWC) + {512, 7, 5, 256}, // filter size (KRSC) + {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 5}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size *mixed* stride (1, 2) and (2, 1), + // filter (3, 3), default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + ///////////////////////////////////////////////////////////////////////////// + // Additional input size + ///////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 28, 28, 256}, // input size (NHWC) + {256, 2, 2, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 32, 32, 16}, // input size (NHWC) + {32, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {6, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {32, 24, 32, 32}, // input size (NHWC) + {32, 1, 2, 32}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 4, 5, 128}, // input size (NHWC) + {256, 3, 6, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 3, 3, 256} // output size (NPQK) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 2, 3, 256}, // input size (NHWC) + {328, 3, 5, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 1, 1, 328} // output size (NPQK) + )); + } + + + // Add a few large and rigorous convolution problem sizes + void initialize_conv2d_rigorous_sizes() { + +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 124, 224, 96}, // input size (NHWC) + {24, 7, 7, 96}, // filter size (KRSC) + {1, 229, 129, 32} // output size (NPQK) + )); + + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 233, 35, 48}, // input size (NHWC) + {24, 7, 5, 48}, // filter size (KRSC) + {1, 233, 35, 24} // output size (NPQK) + )); + +#endif + + } + + + // Add resent50 layers to unit testing sizes + void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ + +#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + [1, 224, 224, 3], // input size (NHWC) + [64, 7, 7, 3], // filter size (KRSC) + [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) + [2, 2], // stride (stride_h, stride_w) + [1, 1], // dilation (dilation_h, dilation_w) + )); +#endif + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {256, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 3, 3, 64}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {64, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {128, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {128, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {512, 1, 1, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {128, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {1024, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {256, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {256, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {1024, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {256, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {2048, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {512, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {512, 3, 3, 512}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {2048, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 2048}, // input size (NHWC) + {512, 1, 1, 2048}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + } + +}; + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedGroupConv2dProblemSizes { + + // + // Data members + // + int threadblock_n; + int threadblock_k; + int minimum_channel_size; + + Conv2dProblemVector default_single_group_sizes; + Conv2dProblemVector default_multiple_group_sizes; + + // + // Methods + // + /// Default ctor + TestbedGroupConv2dProblemSizes( + int threadblock_n_, + int threadblock_k_, + int minimum_channel_size_ = 64) + : threadblock_n (threadblock_n_), + threadblock_k (threadblock_k_), + minimum_channel_size (minimum_channel_size_) { + initialize_group_conv2d_default_sizes(); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &default_single_group_sizes, + &default_multiple_group_sizes + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!((problem.C / problem.groups) % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_group_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////// + // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + // One CTA calculates a single group + //////////////////////////////////////////////////////////////////////////////////// + + for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { + // groups = 2, 3, 4 + for (int groups = 2; groups < 5; ++groups) { + + int conv_k = cta_per_group_k * threadblock_n * groups; + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) + {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + groups // groups + )); + + } // loop groups + } // loop cta_per_group_k + + // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // Larger problem sizes + + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 696}, // input size (NHWC) + {768, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 14, 14, 1392}, // input size (NHWC) + {1536, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + + //////////////////////////////////////////////////////////////////////////////////// + // One CTA calculate multiple groups: CTA::N % k_per_group = 0 + //////////////////////////////////////////////////////////////////////////////////// + + // 2 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 4}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 2 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 4 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 8}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + + // 4 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + } + +}; + + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..34588ecb467b824cc0fcbbff0bc0d99e4385d80e --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + +public: + + TestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." << std::endl; + return false; + } + + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificConv2d( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) + std::vector problem_vectors = { + conv_test_sizes, // run user specified sizes + conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Flatten 2D problem_vectors into a 1D problem_sizes + std::vector problem_sizes; + for (auto problem_vector : problem_vectors) { + for(auto conv_problem : problem_vector) { + problem_sizes.push_back(conv_problem); + } + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) + // run the most rigorous problem size first + if (CutlassUnitTestProblemCount()) { + std::reverse(problem_sizes.begin(), problem_sizes.end()); + } + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // Fixed channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // Few channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + // Although strided dgrad works for all stride combinations, we are only going + // to run strided dgrad for non-unity strides + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + + return true; + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + + return true; + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}) // dilation (dilation_h, dilation_w) + .reset_split_k_slices(2), + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..cf075674da673cf8e056172732f912b8acba3c5b --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class InterleavedTestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_B_reordered; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + InterleavedTestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 3; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + cutlass::reorder_convK( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "ncxhwx_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_cxrskx_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllInterleavedConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + InterleavedTestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + +#if 0 + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } +#endif + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7b2ce61a66a79f852c0aac0895d10ba18e5466 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "conv2d_problems.h" +#include "../../common/cutlass_unit_test.h" +#include "../../gemm/device/testbed_utils.h" + +#include "cutlass/matrix_coord.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_reduce.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Conv, + template class ActivationFunctor +> +struct TestbedConv2dWithAbsMax { + + using ElementAccumulator = typename Conv::ElementAccumulator; + using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax; + static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator; + + static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedConv2dWithAbsMax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()}); + reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<4> origin(0); + tensor_A.host_view().at(origin) = typename Conv::ElementA(1); + tensor_B.host_view().at(origin) = typename Conv::ElementB(1); + tensor_C.host_view().at(origin) = typename Conv::ElementC(1); + tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1, 1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + reference_abs_max_Aux.resize({1, 1, 1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); + passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file0("conv_testbed_with_amax_errors_reference.txt"); + std::ofstream file1("conv_testbed_with_amax_errors_computed.txt"); + + std::ofstream file("conv_testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << std::endl; + + file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl; + file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl; + if (kScaleAux) { + file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl; + file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl; + file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl; + file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl; + } + if (kScaleOutput) { + file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl; + file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl; + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<4> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + cutlass::reference::host::Conv2d< + typename Conv::ElementA, typename Conv::LayoutA, + typename Conv::ElementB, typename Conv::LayoutB, + typename Conv::ElementC, typename Conv::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tmp_D.host_ref(), + scaled_alpha, + scaled_beta + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + } + } + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Conv::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + typename Conv::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + tensor_Aux.device_ref(), + epilogue_params, + cutlass::conv::SplitKMode::kSerial, + tensor_Vector.device_data(), + 0 + }; + + Conv conv2d_op; + + cutlass::Status status = conv2d_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Conv::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv2d_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed" << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ImplicitGemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(); + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector(); + + // + // Testbed object + // + + TestbedConv2dWithAbsMax testbed(scaleA, scaleB, scaleC); + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + bool passed = true; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Prune all problems with channels that aren't divisible by the number of elements accessed per + // load for operands A and B. This is meant to align with the requirements of iterators used for + // fprop kernels. + ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits::value); + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed &= testbed.run(conv_problem); + + if (!passed) { + return false; + } + + // test mode = convolution + passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution)); + + if (!passed) { + return false; + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f768f5b25f425910a49058599d3854352136caef --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -0,0 +1,734 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv2dWithBroadcastReferenceOp { + + using OutputOp = typename Conv2d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv2dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv2d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) +// +// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) +// + +template < + typename Conv2d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv2dWithBroadcast { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv2dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Broadcast.resize({ + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, p, q, k}) = z; + tensor_T_reference.at({n, p, q, k}) = t; + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv2dWithBroadcast( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv2dWithBroadcast( + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..a8ec16ca5de369470f5dc50bb6f8b5e2da3da10d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -0,0 +1,643 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/tensor_reduce.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2dWithReduction { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementT = typename EpilogueOutputOp::ElementTensor; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_Final_Reduction; + + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv2dWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope = 2; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + tensor_Reduction.resize({ + 1, + 1, + (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, + (problem_size.K) + }); + + tensor_Final_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); + + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + static_cast(tensor_Reduction.stride()[0]), + static_cast(tensor_Tensor.stride()[0]) + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // Final reduction over the partial reduction tensor + using Functor = cutlass::plus; + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementAccumulator, + ElementAccumulator, + LayoutC, + Functor, + 8, + ElementAccumulator + >; + + TensorReduction reduction(tensor_Reduction.extent(), 2); + + cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); + + status = reduction.reduce( + tensor_Final_Reduction.device_ref(), + tensor_Reduction.device_ref(), + reduction_device_workspace.get(), + ElementAccumulator()); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // + // Reference check + // + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + + EXPECT_TRUE(passed); + + // + // Reference check on reduction results + // + + tensor_Reduction.sync_host(); + tensor_Final_Reduction.sync_host(); + + // compute backwards for reduction results + cutlass::HostTensor reference_Reduction; + reference_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + for (int k = 0; k < problem_size.K; ++k) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + reduced_value += tensor_D_reference.at({n, p, q, k}); + } + } + } + reference_Reduction.at({0, 0, 0, k}) = reduced_value; + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Final_Reduction.host_view(), + reference_Reduction.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" + << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" + << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" + << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2dWithReduction( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithReduction testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + // Parallel SplitK is not tested. + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..fae7d6194fb671594221a90faea7cac1e5fbeb9f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed sizes for Conv2d problem +*/ +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace test { +namespace conv { +namespace device { + +using Conv3dProblemVector = std::vector; + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv3dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv3dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + Conv3dProblemVector conv3d_default_sizes; + Conv3dProblemVector conv3d_vnet_medical_sizes; + + // + // Methods + // + /// Default ctor + TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + + initialize_conv3d_default_sizes(); + initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); + + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv3dProblemVector *problems_vectors[] = { + &conv3d_default_sizes, + &conv3d_vnet_medical_sizes + }; + + for (Conv3dProblemVector *problems : problems_vectors) { + Conv3dProblemVector filtered; + + for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv3d_default_sizes() { + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 15, 19, 160}, // input size (NDHWC) + {224, 1, 3, 6, 160}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) + {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) + {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 11, 15, 19, 64}, // input size (NDHWC) + {32, 4, 3, 6, 64}, // filter size (KTRSC) + cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + } + + // Add vnet layers to unit testing sizes + void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {32, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {64, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {128, 2, 2, 2, 64}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 4, 4, 4, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {64, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {128, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + } + +}; + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..029f5effb9103bebd4ee61767795d3883541d986 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h @@ -0,0 +1,716 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/util/reference/host/convolution.h" + +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "conv3d_problems.h" +#include "cutlass/core_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv3d { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv3d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute()) { + + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv3d conv3d_op; + + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + cutlass::Status status = conv3d_op.can_implement(conv3d_args); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_size: \n"; + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv3d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv3d output is written to workspace in global memory + conv3d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv3d_args.output_op = {1.0, 0.0}; + // update conv3d operator arguments + status = conv3d_op.update(conv3d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv3dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv3d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta + ); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta + ); +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv3d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "ndhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_ktrsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv3d( + const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); + TestbedConv3d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity) || + (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == + cutlass::conv::StrideSupport::kUnity))) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template +bool TestSpecificConv3d( + const Conv3dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3d testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f8ba785c9d0ecbdd518711714558c9e166c0209a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -0,0 +1,732 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv3d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv3dWithBroadcastReferenceOp { + + using OutputOp = typename Conv3d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv3dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv3d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k]) +// +// Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k]) +// + +template < + typename Conv3d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv3dWithBroadcast { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv3dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) { + + // to make the layout of tensors a little bit bigger than the problem size + cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64); + + cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size); + + if (non_packed_test) { + tensor_A_extent += stride_increment; + tensor_C_extent += stride_increment; + } + + tensor_A.resize(tensor_A_extent); + tensor_B.resize(tensor_B_extent); + tensor_C.resize(tensor_C_extent); + tensor_C_reference.resize(tensor_C_extent); + tensor_Z_computed.resize(tensor_C_extent); + tensor_Z_reference.resize(tensor_C_extent); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(tensor_C_extent); + tensor_Broadcast.resize({ + 1, + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k})); + } + } + } + } + } + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + bool non_packed_test = false, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv3d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size, non_packed_test); + + // configure the operator + Conv3d conv3d_op; + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, o, p, q, k}) = z; + tensor_T_reference.at({n, o, p, q, k}) = t; + } + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nnhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv3dWithBroadcast( + const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(), + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + false,/*non_packed_test*/ + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv3dWithBroadcast( + const Conv3dProblemVector & problem_sizes, + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross, non_packed_test = false + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution, non_packed_test = false + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..cef5f981c595dfbbb95658fb757865b219538192 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h @@ -0,0 +1,473 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Depthwise Direct Conv testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "../cache_testbed_output.h" +#include "conv2d_problems.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/cutlass.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedDepthwiseDirectConv2d { + public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + public: + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_reordered_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + + public: + TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} + + /// Helper to initialize a tensor view + template + void initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } else { + scope = 5; + } + } else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } else { + } + } + + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_reordered_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient(int smem_size) const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < static_cast(smem_size)) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1.5), + ElementCompute beta = ElementCompute(1)) { + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " + << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") + << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args(problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + tensor_reordered_B.device_ref(), + split_k_mode); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.can_implement(problem_size); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + if (!sufficient(conv2d_op.get_smem_size())) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." << std::endl; + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = + CreateCachedConv2dTestKey(kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view()); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_DirectConv_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { + bool passed = true; + + // + // Testbed object + // + TestbedDepthwiseDirectConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (auto conv_problem : problem_sizes) { + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54c11281e14b813b249d7f9710542843b37bcc68 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -0,0 +1,1385 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem +*/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::vector> +inline +get_conv_problem_vector(); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Fprop +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {800, 80, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {512, 64, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nqk) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 4, 64}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {8000, 800, 80, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {4096, 512, 64, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (npqk) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 3, 3, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,2/1,2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 2, 5, 64}, + {1, 1}, + {2, 2}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 7, 7, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // non-packed input output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, + {16, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, + {96, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 3, 3, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 4, 32}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1024, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1040, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 15, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 64, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 65, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 64, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 65, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Grouped Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Get problem size vectors for group conv problems +template +std::vector> +inline +get_grouped_conv_problem_vector(int GroupsPerTile); + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + + if (GroupsPerTile == 1) { + // channel_per_group == 64 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 2048}, // ndhwc + {2048, 1, 3, 3, 64}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 2) { + // channel_per_group == 32 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 1024}, // ndhwc + {1024, 1, 3, 3, 32}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 4) { + // channel_per_group == 16 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 512}, // ndhwc + {512, 1, 3, 3, 16}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 8) { + // channel_per_group == 8 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 256}, // ndhwc + {256, 1, 3, 3, 8}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Unit Stride Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {512, 64, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 16}, + {64, 1, 16}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 96}, + {64, 1, 96}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 256}, + {64, 1, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 3, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {32, 3, 256}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 4, 256}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {8000, 800, 80, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {4096, 512, 64, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (nhwc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 16}, + {64, 1, 1, 16}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 96}, + {64, 1, 1, 96}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 256}, + {64, 1, 1, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 3, 3, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {32, 3, 3, 256}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 2, 5, 256}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 16}, + {64, 1, 1, 1, 16}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // non-packed input output strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // nzpqk + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 96}, + {64, 1, 1, 1, 96}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Test TMA truncation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 512, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {2}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1024, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 2048, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {8}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride divides dilation + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {1}, // padding upper (pad_w) + {2}, // stride (stride_w) + {4}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // dilation divides stride + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {2}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride dilation dont divide + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {2}, // padding upper (pad_w) + {2}, // stride (stride_w) + {3}, // dilation (dilation_w) + 1 // group + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 stride divides dilation + // mode 1 dilation divides stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {2, 4}, + {4, 2}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 dilation divides stride + // mode 1 stride divides dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {4, 2}, + {2, 4}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // stride dilation dont divide + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {3, 2}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {2, 1, 2}, + {4, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99ba9c407cec38e919812fedeee38ba75d9129f7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed for 3.x API +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "../../common/cutlass_unit_test.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "../test/unit/gemm/device/gemm_testbed_3x.hpp" + +#include "thrust/universal_vector.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/conv.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "conv_problem_sizes.hpp" +#include "../cache_testbed_output.h" + +#include + +#include "cute/layout.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Initializes a flat device buffer +template +static void +initialize_values( + thrust::universal_vector& dst_ptr, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (cutlass::Distribution::Uniform == dist_kind) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0); + } + else if (cutlass::Distribution::Identity == dist_kind) { + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0); + } + else if (cutlass::Distribution::Gaussian == dist_kind) { + cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5); + } + else if (cutlass::Distribution::Sequential == dist_kind) { + cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size()); + } + else { + std::cerr << "Invalid distribution kind!\n."; + exit(1); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// utils for sparse or dense conv parameters + +template +struct DenseConvParams { + // Default Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + + // get the default arguments without sparse data + auto get_mainloop_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + thrust::universal_vector& tensor_A, + thrust::universal_vector& tensor_B + ) { + auto args = typename Conv::ConvKernel::MainloopArguments { + tensor_A.data().get(), + tensor_B.data().get(), + }; + return args; + } +}; + +template +struct SparseConvParams { +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ConvTestbed { + // Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + using ElementC = cute::conditional_t, + typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>; + using ElementD = typename Conv::ConvKernel::ElementD; + using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator; + + // ConvTest for sparse kernel + static constexpr bool isSparseEnabled = isSparseEnabled_; + using ConvParams = cute::conditional_t, DenseConvParams>; + ConvParams params; + + // + // FusionOperation derived types/queries + // + using FusionOp = typename Conv::EpilogueOutputOp; + + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementCompute = typename FusionOp::ElementCompute; + using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::type; + using ElementBias = non_void_t; + using ActivationType = non_void_t::type, + cutlass::epilogue::thread::Identity>; + static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation::value; + using ActivationFunctor = cute::conditional_t>; + + static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && + !cute::is_same_v; + static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled::value; + + static constexpr bool DisableSource = cute::is_void_v; + + static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd::value; + + using StrideC = typename Conv::ConvKernel::StrideC; + using StrideD = typename Conv::ConvKernel::StrideD; + using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + using Schedule = typename Conv::DispatchPolicy::Schedule; + /// Initialization + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros + uint64_t seed = 6090; + float epsilon = 0.0f; + int split_p_slices = 1; + thrust::universal_vector tensor_A; + thrust::universal_vector tensor_B; + thrust::universal_vector tensor_C; + thrust::universal_vector tensor_D_computed; + thrust::universal_vector tensor_D_reference; + thrust::universal_vector tensor_bias; + thrust::universal_vector tensor_alpha; + thrust::universal_vector tensor_beta; + + // Return true on success, else false + bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { + tensor_A.resize(sizeof(ElementA) * problem_shape.size_A()); + tensor_B.resize(sizeof(ElementB) * problem_shape.size_B()); + tensor_C.resize(sizeof(ElementC) * problem_shape.size_C()); + tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + if constexpr (IsPerChannelScaleEnabled) { + tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + } + initialize_values(tensor_A, init_A, seed); + initialize_values(tensor_B, init_B, seed * 11); + initialize_values(tensor_C, init_C, seed * 17); + initialize_values(tensor_bias, init_bias, seed * 19); + if constexpr (IsPerChannelScaleEnabled) { + initialize_values(tensor_alpha, init_bias, seed * 23); + if constexpr (DisableSource) { + initialize_values(tensor_beta, init_disable, seed * 27); + } + else { + initialize_values(tensor_beta, init_bias, seed * 27); + } + } + + bool flag = true; + if constexpr (isSparseEnabled) { + flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); + } + + return flag; + } + + // Determine SMEM requirements and waive if not satisfied + bool sufficient() const { + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + int max_smem_size; + result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaDeviceGetAttribute() failed"); + } + + return max_smem_size >= Conv::ConvKernel::SharedStorageSize; + } + + auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) { + using TensorExtent = cute::array; + using TensorStride = cute::array; + + TensorExtent shape_a_g{}; + TensorExtent shape_b_g{}; + TensorExtent shape_c_g{}; + TensorStride stride_a_g{}; + TensorStride stride_b_g{}; + TensorStride stride_c_g{}; + + auto shape_a = cute::reverse(problem_shape.shape_A); + auto shape_b = cute::reverse(problem_shape.shape_B); + auto shape_c = cute::reverse(problem_shape.shape_C); + auto stride_a = cute::reverse(problem_shape.stride_A); + auto stride_b = cute::reverse(problem_shape.stride_B); + auto stride_c = cute::reverse(problem_shape.stride_C); + + int32_t G = problem_shape.groups; + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop || + ConvOp == cutlass::conv::Operator::kDgrad) { + // shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g) + // shape_b_g = (c,s,r,k,t,g) + // shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_b), + cute::make_shape(cute::size(shape_b) / G, G))); + shape_c_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_c) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_c), + cute::make_shape(G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, + cute::size(stride_b) * cute::size(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, cute::size<0>(shape_c) / G)); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + // shape_a_g = (k,q,p,z,n,g) + // shape_b_g = (c,w,h,d,n,g) + // shape_c_g = (c,s,r,k,t,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_b) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_b), + cute::make_shape(G))); + shape_c_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_c), + cute::make_shape(cute::size(shape_c) / G, G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, cute::size<0>(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, + cute::size(stride_c) * cute::size(shape_c) / G)); + } + + return make_tuple(shape_a_g, shape_b_g, shape_c_g, + stride_a_g, stride_b_g, stride_c_g); + } + + // Executes one test + bool run( + ProblemShape const& problem_shape, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, + Splits splits = Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + ) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device.\n"; + } + return true; + } + + bool ret = initialize(problem_shape); + + if (!ret) { + std::cerr << "initialize failed for the given problem_shape: \n"; + return false; + } + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + hw_info.cluster_shape = cluster_shape; + hw_info.cluster_shape_fallback = cluster_shape_fallback; + + // configure the operator + Conv conv_op; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + } + // Need to support non-packed output strides for fprop and dgrad kernel. + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { + cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { + cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{}; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + + auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B); + + auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments { + {}, + tensor_C.data().get(), + stride_C, + tensor_D_computed.data().get(), + stride_D, + }; + + auto args = typename Conv::Arguments { + problem_shape, + mainloop_args, // MainloopArguments + epilogue_args, // EpilogueArguments + hw_info, + scheduler_args + }; + + auto &fusion_args = args.epilogue.thread; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + fusion_args.alpha_ptr = tensor_alpha.data().get(); + fusion_args.beta_ptr = tensor_beta.data().get(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = tensor_bias.data().get(); + } + + // Clamp bound + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); + fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); + } + + // Scale + if constexpr (cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> ) { + fusion_args.activation.scale = ElementCompute{1}; + } + + // LeakyRelu + if constexpr (cute::is_same_v> ) { + fusion_args.activation.leaky_alpha = ElementCompute{0}; + } + + cutlass::Status status = cutlass::Status::kInvalid; + + status = conv_op.can_implement(args); + EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_shape: \n"; + print(problem_shape); + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv::get_workspace_size(args); + thrust::universal_vector workspace(workspace_size); + + status = conv_op.initialize(args, workspace.data().get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: " + << cudaGetErrorString(result); + + // Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us + auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] = + transform_shape_and_stride_with_groups(problem_shape); + auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B()))); + + auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA)); + auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB)); + auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC)); + auto mBias = make_tensor(tensor_bias.data().get(), make_layout(shape_mBias)); + auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias)); + auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias)); + + cutlass::reference::host::ConvEpilogueFusionParams< + ElementAccumulator, + ElementScalar, + ElementCompute, + ElementC, + ElementD, + IsResidualEnabled, + decltype(mAlpha), + decltype(mBeta), + decltype(mBias), + ActivationFunctor> + epilogue_fusion_params{}; + + epilogue_fusion_params.alpha = alpha; + epilogue_fusion_params.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + epilogue_fusion_params.tensor_alpha = mAlpha; + epilogue_fusion_params.tensor_beta = mBeta; + } + + if constexpr (IsBiasEnabled) { + epilogue_fusion_params.tensor_bias = mBias; + } + + auto padding = cute::reverse(problem_shape.lower_padding); + auto tstride = cute::reverse(problem_shape.traversal_stride); + auto dilation = cute::reverse(problem_shape.dilation); + + cutlass::reference::host::ConvReferenceImpl< + ConvOp, + NumSpatialDimensions, + decltype(mA), + decltype(mB), + decltype(mC), + decltype(mD_ref), + decltype(padding), + decltype(tstride), + decltype(dilation), + decltype(epilogue_fusion_params)> + reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey< + ProblemShape, + ElementA, + ElementB, + ElementC, + ElementD + >( + ConvOp, + problem_shape, + alpha, + beta, + tensor_A, + tensor_B, + tensor_C + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string convnd_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + CachedTestResultListing cached_results(convnd_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + #endif + + if (!cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + cached_test_result.D = TensorHash(tensor_D_reference); + CachedTestResultListing cached_results(convnd_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(convnd_result_cache_name); + #endif + } // if (!cached_result_loaded) + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed); + passed = (tensor_D_computed_hash == cached_test_result.D); + // If hash fails, double check against reference implementation. + if(!passed) { + std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key + << ", comparing with reference implementation now.\n"; + if (cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + } + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + } + #else + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + #endif + + EXPECT_TRUE(passed); + return passed; + } + + template< + class Engine, class Layout, + class EngineA, class LayoutA, + class EngineB, class LayoutB, + class EngineAlpha, class LayoutAlpha, + class EngineBeta, class LayoutBeta, + class EngineBias, class LayoutBias> + static constexpr bool + compare_reference( + cute::Tensor const& reference, + cute::Tensor const& computed, + cute::Tensor const& A, + cute::Tensor const& B, + cute::Tensor const& tensor_alpha, + cute::Tensor const& tensor_beta, + cute::Tensor const& tensor_bias, + float epsilon = 0.0f) { + if (size(reference) != size(computed)) { + return false; + } + + bool passed = true; + if (epsilon == 0.0f) { + // fast refcheck w/o epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + if (reference(i) != computed(i)) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } else { + // refcheck with epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + auto ref = static_cast(reference(i)); + auto act = static_cast(computed(i)); + auto abs_error = std::abs(act - ref); + auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f); + if (std::isnan(abs_error) || std::isnan(rel_error) || + std::min(abs_error, rel_error) > epsilon) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } + #if CUTLASS_DEBUG_TRACE_LEVEL > 1 + if (not passed) { + cute::print("Reference:"); + cute::print_tensor(reference); + cute::print("\nComputed:"); + cute::print_tensor(computed); + cute::print("\n"); + + for (size_t i = 0; i < size_t(size(A)); ++i) { + printf("[%llu]: A = %f\n", static_cast(i), float(A(i))); + } + for (size_t i = 0; i < size_t(size(B)); ++i) { + printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); + } + if constexpr (IsPerChannelScaleEnabled) { + for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) { + printf("[%llu]: alpha = %f\n", static_cast(i), + float(tensor_alpha(i))); + } + for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) { + printf("[%llu]: beta = %f\n", static_cast(i), + float(tensor_beta(i))); + } + } + if constexpr (IsBiasEnabled) { + for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { + printf("[%llu]: bias = %f\n", static_cast(i), + float(tensor_bias(i))); + } + } + for (size_t i = 0; i < size_t(size(reference)); ++i) { + printf("[%llu]: ref = %f, computed = %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + } + } + #endif + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f, + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0) + ) { + using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; + + bool passed = true; + ConvTestbed testbed; + testbed.epsilon = epsilon; + auto problem_vector = get_conv_problem_vector< + Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>(); + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + + for (auto conv_problem : problem_vector) { + #if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print(conv_problem); + #endif + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(Splits{2}); + problem_splits.push_back(Splits{4}); + } + } + for (auto splits : problem_splits) { + + passed = testbed.run( + conv_problem, + cutlass::from_real(alpha), + cutlass::from_real(beta), + cluster_shape, + cluster_shape_fallback, + RasterOrderOptions::Heuristic, // raster_order + MaxSwizzleSize(1), + splits, + decomp_mode + ); + if (!passed) { + printf("Failed test for "); print(conv_problem); + return false; + } + } // splits + } // decomposition_mode + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff170be142ff9d0d02cc684c2873c3ec014bd236 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +template +__global__ void +test_tiled_cp_async_device_cute(T const* g_in, T* g_out, + TiledCopy const tiled_copy, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto thr_copy = tiled_copy.get_slice(threadIdx.x); + Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); + + auto tAgA = thr_copy.partition_S(gA); + auto tAsA = thr_copy.partition_D(sA); + +#if 0 + if (thread0()) { + print("gA : "); print(gA.layout()); print("\n"); + print("sA : "); print(sA.layout()); print("\n"); + print("tAgA: "); print(tAgA.layout()); print("\n"); + print("tAsA: "); print(tAsA.layout()); print("\n"); + } +#endif + + copy(tiled_copy, tAgA, tAsA); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Store trivially smem -> gmem + + if (thread0()) { + copy(sA, gB); + } + +} + +template +void +test_tiled_cp_async( + TiledCopy const tiled_copy, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tiled_copy, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < size(hA_out) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +template +void test_cp_async_no_swizzle() { + using namespace cute; + auto smem_atom = SMEM_LAYOUT{}; + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} + +template +void test_cp_async_with_swizzle() { + using namespace cute; + auto swizzle_atom = SWIZZLE_ATOM{}; + auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{}); + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ff20d4087ee2fd6f4f74338e3e63eef27c221d3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/relatively_equal.h" +#include "cutlass_unit_test.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct fp64_tester { + using value_type = double; +}; + +template +struct fp64_tester> { + using value_type = complex; +}; + +template // logical shape (M, N) +auto host_generate_gemm_inputs( + ALayout a_layout, + BLayout b_layout, + CLayout c_layout +) { + thrust::host_vector h_a(cosize(a_layout)); + thrust::host_vector h_b(cosize(b_layout)); + thrust::host_vector h_c(cosize(c_layout)); + thrust::host_vector h_c_out(cosize(c_layout)); + + auto h_a_tensor = make_tensor(h_a.data(), a_layout); + auto h_b_tensor = make_tensor(h_b.data(), b_layout); + auto h_c_tensor = make_tensor(h_c.data(), c_layout); + size_t max_size = std::max({static_cast(size(a_layout)), + static_cast(size(b_layout)), + static_cast(size(c_layout))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(a_layout)) { + h_a_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(b_layout)) { + h_b_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(c_layout)) { + h_c_tensor(i) = static_cast((di*di) / size(a_layout)); + } + } + + return std::make_tuple(h_a, h_b, h_c, h_c_out); +} + +template +thrust::host_vector +host_reference_gemm(Alpha alpha, + Tensor const& h_a_tensor, + Tensor const& h_b_tensor, + Beta beta, + Tensor const& h_c_tensor, + ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) + { + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TA = remove_cv_t; + using TB = remove_cv_t; + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + return h_c_ref; +} + +template +void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, + cute::Tensor const& h_c_ref_tensor) +{ + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + for (int i = 0; i < size(h_c_ref_tensor); i++) { + ABC_64 h_c_ref_i = h_c_ref_tensor(i); + ABC_64 h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + Alpha const alpha, + Beta const beta, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op, + SMemCopyLdOpC c_copy_ld_op, + SMemCopyStOpC c_copy_st_op) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op + ); + __syncthreads(); + + cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op) + { + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Create C fragment for storing intermediate results + auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); + Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); + Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); + Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); + + // Create indexing help for predicated GEMMs + Tensor cC = make_identity_tensor(shape(gmem_c_layout)); + Tensor tCcC = thr_mma.partition_C(cC); + + // Load C from global + // (always loading in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + r_c_partition(i) = c_load_transform(g_c_partition(i)); + } + } + + cooperative_gemm( + threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, + a_load_transform, b_load_transform, a_copy_op, b_copy_op + ); + + __syncthreads(); + + // Store C to global + // (always storing in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + g_c_out_partition(i) = c_store_transform(r_c_partition(i)); + } + } +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}, + CSMemCopyLdOp c_smem_copy_ld_op = {}, + CSMemCopyStOp c_smem_copy_st_op = {}) +{ + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM + static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto beta = static_cast(1.2); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + sizeof(TC) * h_c.size(); + + + auto kernel = cooperative_gemm_kernel< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, SMemCLayout, + TA, TB, TC, decltype(alpha), decltype(beta), + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + tiled_mma, + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform, + a_smem_copy_op, + b_smem_copy_op, + c_smem_copy_ld_op, + c_smem_copy_st_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Copy result data + h_c_out = d_c_out; + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}) +{ + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.0); + const auto beta = static_cast(1.0); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = + host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); + + + auto kernel = cooperative_gemm_kernel_rmem_c< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, + TA, TB, TC, + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + tiled_mma, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_smem_copy_op, b_smem_copy_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Copy result data + h_c_out = d_c_out; + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template +void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma, + ops...); +} + + +template +std::enable_if_t, + cute::is_layout, + cute::is_layout>> +test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + SMemAtomLayoutC smem_atom_layout_c, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops&& ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_atom_layout_c, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + ops...); +} + + +template +void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + + test_cooperative_gemm_rmem_c + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + tiled_mma, + ops...); +} + +template +std::enable_if_t, + cute::is_layout>> +test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + test_cooperative_gemm_rmem_c + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + tiled_mma, + ops...); +} + +template +void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout_rmem_c, + T, T, T> + (static_cast(args)...); +} + +template +void test_cooperative_gemm_col_major_layout(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout, + T, T, T> + (static_cast(args)...); +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d2620e62ff247e36ae49809ab4ef3416560ae31 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp @@ -0,0 +1,217 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + + // + // Prepare the TMA_LOAD + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA_x: "); print(tAgA_x); print("\n"); + print("tAsA_x: "); print(tAsA_x); print("\n"); + } +#endif + + // + // Perform the TMA_LOAD + // + + // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST) + Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST) + static_assert(size<1>(tAsA) == 1); + + // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output + Tensor tBgB = group_modes<0,R>(group_modes(gB)); // (CTA_TILE, REST) + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + } +#endif + + // Test L2 prefetch + if (threadIdx.x == 0) { + prefetch(tma, tAgA); + } + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA))); + + if (threadIdx.x == 0) + { + /// Initialize shared memory barrier + tma_load_mbar[0] = 0; + cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value + constexpr int kPhaseBit = 0; + cute::wait_barrier(tma_load_mbar[0], kPhaseBit); + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3e0ec46df1b672c35c3c38f731c09b0134d4cd80 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include +#include +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, + CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + +#if 1 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Prepare the TMA_LOAD + // + + Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) + Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); + +#if 1 + if (thread0()) { + print("sA_x : "); print(sA_x); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // TMA Multicast Masks -- Get a mask of the active ctas in each TMA + // + + + int elected_cta_rank = 0; + bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); + bool elect_one_thr = cute::elect_one_sync(); + + uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); + +#if 1 + if (thread0()) { + print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Perform the TMA_LOAD + // + + if (elect_one_thr) { + // Initialize TMA barrier + cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); + } + int tma_phase_bit = 0; + // Ensures all CTAs in the Cluster have initialized + __syncthreads(); + cute::cluster_sync(); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + + if (elect_one_thr) + { + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value + cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); + tma_phase_bit ^= 1; + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (elect_one_cta && elect_one_thr) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + cute::cluster_sync(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); + //print(tma); + + // Launch + + dim3 dimBlock(32); + dim3 dimCluster(size(cluster_size)); + dim3 dimGrid = dimCluster; + int smem_size = sizeof(SharedStorage); + + void* kernel_ptr = (void*) &tma_test_device_cute; + + cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, + kernel_ptr, + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast(raw_pointer_cast(d_out.data())), + gmem_layout, + smem_layout, + tma, cta_tiler, cluster_size); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0429d2435fbf43c690f311c1f7c04f7025a2dd94 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + + // + // Prepare the TMA_STORE + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) + Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mB : "); print( mB.data()); print(" o "); print( mB.layout()); print("\n"); + print(" gB : "); print( gB.data()); print(" o "); print( gB.layout()); print("\n"); + print("tBgB_x: "); print(tBgB_x.data()); print(" o "); print(tBgB_x.layout()); print("\n"); + print(" sB : "); print( sB.data()); print(" o "); print( sB.layout()); print("\n"); + print("tBsB_x: "); print(tBsB_x.data()); print(" o "); print(tBsB_x.layout()); print("\n"); + } +#endif + + // + // Perform the TMA_STORE + // + + // INPUT: Group the CTA_TILE_X modes and REST_X modes for input + Tensor tAgA = group_modes<0,R>(group_modes(gA)); // (CTA_TILE, REST) + + // OUTPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tBgB = group_modes<1,rank(tBgB_x)>(tBgB_x); // (TMA,REST) + Tensor tBsB = group_modes<1,rank(tBsB_x)>(tBsB_x); // (TMA,REST) + static_assert(size<1>(tBsB) == 1); + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print("tBsB : "); print(tBsB.data()); print(" o "); print(tBsB.layout()); print("\n"); + print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); + } +#endif + + // Test L2 prefetch + cooperative_prefetch<128>(threadIdx.x, gA); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tBgB); ++stage) + { + // + // Read in trivially gmem -> smem + // + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(tAgA(_,stage), sB); + } + + __syncthreads(); + cute::cp_async_wait<0>(); + + // + // Perform the TMA_STORE + // + + if (threadIdx.x == 0) { + copy(tma, tBsB(_,0), tBgB(_,stage)); + } + + tma_store_wait<0>(); + __syncthreads(); + } +} + +template +void +test_tma_store(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_out.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..3163a0d0eaa24513ee210bd2b310d1bf233773a9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_with_reduction_threadblock( + typename Epilogue::ElementVector *ptr_Reduction, + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::TensorTileIterator::Params params_Tensor, + typename Epilogue::TensorTileIterator::Element *ptr_Tensor, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::TensorTileIterator iterator_T( + params_Tensor, + ptr_Tensor, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueWithReductionTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementTensor = typename Epilogue::TensorTileIterator::Element; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + cutlass::HostTensor additional_tensor; + cutlass::HostTensor reduction_tensor; + + +public: + + // + // Methods + // + + EpilogueWithReductionTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + reduction_tensor({1, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + + cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); + } + + bool run_all() { + + /* + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + return true; + */ + + double alpha = 1; + double beta = 0; + + return run( + {quantized_size.row(), quantized_size.column()}, + {cutlass::from_real(alpha), cutlass::from_real(beta)}); + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + ElementAccumulator default_reduction = ElementAccumulator(); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + additional_tensor.sync_device(); + reduction_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( + reduction_tensor.device_data(), + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + params_T, + additional_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + reduction_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + // + // The output has two parts: + // - GEMM tensor epilogue in canonical layout + // - partial reduction in canonical row-major layout + // + + // Verify the GEMM tensor output + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord))); + } + else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // Verify the partial reduction + for (int c = 0; c < quantized_size.column(); ++c) { + + ElementAccumulator reduction_acc = ElementAccumulator(); + + for (int r = 0; r < quantized_size.row(); ++r) { + reduction_acc += accumulator_tensor.at({r, c}); + } + + ElementAccumulator expected = default_reduction; + ElementAccumulator got = reduction_tensor.at({0, c}); + + if (c < problem_size.column()) { + expected = reduction_acc; + } + else { + expected = default_reduction; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - reduction element (" << c << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e2457fdb4817e1dfb3af73149ae1e4c4458670a2 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h @@ -0,0 +1,356 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/platform/platform.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. + CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, iterator_D, accumulators, iterator_C); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + +public: + + // + // Methods + // + + EpilogueTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 2, + -2, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 2, + -2, + 0); + } + + bool run_all() { + + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + ElementCompute intermediate = + output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord)); + + if ((cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || std::numeric_limits::is_integer) + && !std::numeric_limits::is_integer) { + std::fesetround(FE_TONEAREST); + expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); + } else { + expected = ElementOutput(intermediate); + } + } else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) + << ", accum: " << (accumulator_tensor.at(coord)) + << ", source: " << OutputIO(source_tensor.at(coord)) + << ", alpha: " << (output_params.alpha) + << ", beta: " << (output_params.beta) << "\n"; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a76578f7638ac1d30161a9bcb55ecec70b5c43e0 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h @@ -0,0 +1,394 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_planar_complex_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + int64_t imaginary_stride_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + int64_t imaginary_stride_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int64_t imaginary_stride_accum, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D_real( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params_D, + ptr_D + imaginary_stride_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C_real( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params_C, + ptr_C + imaginary_stride_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + // + // Load accumulators + // + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + + accumulator_iterator.load(accumulators.real); + accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop so assembly is clearly visible + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpiloguePlanarComplexTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + + using ComplexElementOutput = cutlass::complex; + using ComplexElementAccumulator = cutlass::complex; + using ComplexElementCompute = cutlass::complex; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensorPlanarComplex accumulator_tensor; + cutlass::HostTensorPlanarComplex source_tensor; + cutlass::HostTensorPlanarComplex output_tensor; + +public: + + // + // Methods + // + + EpiloguePlanarComplexTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + #if 1 + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + #else + + cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); + + #endif + } + + bool run_all() { + + cutlass::complex alpha_values[3]; + + alpha_values[0] = cutlass::complex(1, 0); + alpha_values[1] = cutlass::complex(0, 0); + alpha_values[2] = cutlass::complex(2.25f, -0.5f); + + cutlass::complex beta_values[3]; + + beta_values[0] = cutlass::complex(0, 0); + beta_values[1] = cutlass::complex(1, 0); + beta_values[2] = cutlass::complex(0.5f, -2.25f); + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. + for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + cutlass::MatrixCoord problem_size( + quantized_size.row() - m_idx * 3, + quantized_size.column() - n_idx * Epilogue::kElementsPerAccess + ); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = run(problem_size, {alpha, beta}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + output_tensor.imaginary_stride(), + params_C, + source_tensor.device_data(), + source_tensor.imaginary_stride(), + output_params, + problem_size, + accumulator_tensor.device_view_real(), + accumulator_tensor.imaginary_stride() + ); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ComplexElementOutput got = output_tensor.at(coord); + + ComplexElementOutput expected = default_output; + + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + ComplexElementOutput src = source_tensor.at(coord); + + ComplexElementCompute tmp = + output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ComplexElementCompute(src.real(), src.imag()); + + expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + + + std::cout << "Incorrect result for problem(" + << problem_size.row() << ", " + << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; + + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + + std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0054a1b6757a232e9177407fdd2041b6a91cffb9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp @@ -0,0 +1,1384 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace cutlass { +namespace gemm { +namespace device { +using namespace cute; + +// This type is only intended to demonstrate porting 2.x kernels to 3.0 +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types { + static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +// +// F16: 128-by-128-by-64 +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,3,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32F16 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + half_t, LayoutA, + half_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_16>>; // 32x32x16 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 8; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + half_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 8; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + half_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + half_t, TagToStrideA_t, + half_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// TF32: 128-by-128-by-kblock (kBlock = 16, 32) +// + +/// Operand A - Row-major (K-major) (kBlock = 32) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Row-major (K-major) (kBlock = 16) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,2,3>{}, + Layout, + Stride<_16, _1>>{})); + using SmemCopyAtom = Copy_Atom; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride< _1,_32>>{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32TF32 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + tfloat32_t, LayoutA, + tfloat32_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 4; + using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + tfloat32_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 4; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + tfloat32_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + tfloat32_t, TagToStrideA_t, + tfloat32_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _64>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentA = 16; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomA = Copy_Atom; // LDSM works + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentB = 16; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomB = Copy_Atom; + using SmemCopyAtomB = Copy_Atom; // LDSM works + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + int8_t, TagToStrideA_t, + int8_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + int32_t, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// SIMT TWO STAGE /////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_Simt_OperandA; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,_128>>; + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); +}; + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,Int<128 + 4>>>; // Padded + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + +}; + +template +struct DefaultGemm_Simt_OperandB; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +} // end namespace detail + +// SIMT Two Stage +template < + class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _8>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A + static constexpr int kAlignmentA = 1; + using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 1; + using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + + +// +// DP4A - int8 Proof-of-concept +// + +// SIMT Two Stage TN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; // Tile of atoms (threads) + + // A (M,K) K-major + using ElementA = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomA = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomB = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + + using DispatchPolicy = MainloopSm70TwoStage; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage TT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization + Layout,Stride<_2,_1>>,Underscore>>; + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TN (K-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; + +/* + using EpilogueOutputOp = epilogue::collective::Epilogue< + epilogue::thread::LinearCombination, + Layout, + Stride< _1,_64>>, // SMEM layout + Copy_Atom,double>, // R2S with tiled_mma layout + decltype(make_tiled_copy(Copy_Atom,double>{},// S2R + Layout, + Stride< _1,_16>>{}, // Thread layout + Layout>{})), // Value layout + Copy_Atom,double> // R2G with S2R_dst layout + >; +*/ +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NN (M-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{}));// N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NT (M-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TT (K-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Hopper fp64 MMA TN +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm90, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + make_ordered_layout(Shape<_128,_16>{}, + Step < _2, _1>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + make_ordered_layout(Shape<_64,_16>{}, + Step < _2, _1>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + double, double, + double, cutlass::layout::ColumnMajor, 1, + double, cutlass::layout::ColumnMajor, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89755dd7d3162b114a537e58c6aa33cac80078f9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -0,0 +1,3993 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include // std::lcm + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/detail/collective.hpp" + +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail { + +inline constexpr auto decomp_mode_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::DataParallel) { + return "DataParallel"; + } + else if (mode == Mode::SplitK) { + return "SplitK"; + } + else if (mode == Mode::StreamK) { + return "StreamK"; + } + else { + return "Unknown"; + } + }; + +inline constexpr auto raster_order_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::AlongM) { + return "AlongM"; + } + else if (mode == Mode::AlongN) { + return "AlongN"; + } + else { + return "Unknown"; + } + }; + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +template +struct IsSfdEpi : cute::false_type {}; + +template +struct IsSfdEpi> : cute::true_type {}; + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsLegacyEpiloguePolicy { + static constexpr bool value = false; +}; + +template +struct IsLegacyEpiloguePolicy> { + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool value = cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize>>; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB, + class Enable = void> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + StrideA stride_a; + StrideB stride_b; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (generic)::initialize(problem_shape)"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.resize"); +#endif + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.resize"); +#endif + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an unknown exception"); + throw; + } + + try { + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked_initialize_tensor threw an unknown exception"); + throw; + } + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Check last error before sync_device()"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: cudaGetLastError() is " << error_str); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.host_data()=" << tensor_A.host_data() << ", tensor_A.device_data()=" << tensor_A.device_data()); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.host_data()=" << tensor_B.host_data() << ", tensor_B.device_data()=" << tensor_B.device_data()); + } +#endif + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.sync_device"); +#endif + tensor_A.sync_device(); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.sync_device"); +#endif + tensor_B.sync_device(); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Reached end"); +#endif + return true; + } + + Arguments to_args() { + + + // Runtime datatype selection + if constexpr (not cute::is_same_v) { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b + }; + } + else { + + Arguments arguments = + { + tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b + }; + return arguments; + } + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + + auto dummy_SFA = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto dummy_SFB = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + bool passed = true; + return passed; + } +}; + +// +// Sparse MMA host implementation +// +template< + class Gemm, + class ElementA_, + class ElementB_> +struct HostCollectiveMainloopSparse +{ + + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. + using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. + using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ArchTag = typename Gemm::ArchTag; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr int MaxSmCount = 16; + + HostCollectiveMainloopSparse( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloopSparse::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + return true; + } +}; + +template< + class ScheduleType_, + class Gemm, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + typename Gemm::CollectiveMainloop::DispatchPolicy>>> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse MMA input Operands : A_compressed, B, metadata +// +// Structured Sparse Gemm Input Operands + +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideB stride_b; + + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + tensor_A.sync_device(); + tensor_B.sync_device(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Structured Sparse Gemm Input Operands : A_compressed, B, metadata, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. + using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. + using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm100>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelSparseTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + // return {A, SfA, B, SfB}; + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + cutlass::HostTensor tensor_C; + // Inputs + ElementScalar alpha; + ElementScalar beta; + + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + { + const bool init_succeeded = initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not init_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(init_succeeded); + } + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + auto [M, N, K, L] = problem_shape_MNKL; + + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + static constexpr bool IsLegacy = detail::IsLegacyEpiloguePolicy::value; + + // FFMA2 SGEMM uses ThreadEpilogueOp for bias and relu support instead of FusionOp, so we compose LinCombPerRowBiasEltAct FusionOp by hand to test the functionality. + static constexpr bool IsFfma2Kernel = cute::is_same_v; + using FusionOp = cute::conditional_t, + typename Gemm::EpilogueOutputOp>; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + static constexpr bool IsKMajorSFD = cute::is_same_v; + using ElementSFD = non_void_t; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + cutlass::HostTensor tensor_SFD; + cutlass::HostTensor reference_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsColBiasEnabled = FusionOp::IsPerColBiasSupported; + static_assert(not (IsColBiasEnabled && IsRowBiasEnabled)); + + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsPerColScaleEnabled = FusionOp::IsPerColScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor tensor_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // References + cutlass::HostTensor reference_dbias; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If vector scale is supported and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + + try { + const bool initialize_tensor_C_succeeded = + initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not initialize_tensor_C_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(initialize_tensor_C_succeeded); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an unknown exception"); + throw; + } + + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + auto row_vector_coord = cutlass::make_Coord(N); + auto batch_vector_coord = cutlass::make_Coord(L); + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + // scalars + if (vector_scale_mode == VectorScale::DISABLED) { + // batched scalars + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // non-batched scalars + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // batched vectors + else { + auto batched_vector_coord = cutlass::make_Coord((IsPerRowScaleEnabled ? M : N) * L); + alpha.resize(batched_vector_coord, true); + beta.resize(batched_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + // Set alpha beta for different batches. + alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + for (int l = 0; l < L; ++l) { + beta.host_view().at(cutlass::make_Coord(l)) = beta_ + ElementScalar(l); + } + } + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + bias.resize(IsRowBiasEnabled ? col_vector_coord : row_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + EXPECT_TRUE(initialize_tensor(tensor_Aux.host_view(), init_C, seed + 2023)); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + } + + if constexpr (IsAuxOutEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + reference_Aux.resize(aux_coord, aux_layout, false); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + if constexpr (IsKMajorSFD) { + return cutlass::make_Coord(m_blks * Blk_MN{} * L, n_blks * Blk_SF{}); + } + else { + return cutlass::make_Coord(m_blks * Blk_SF{} * L, n_blks * Blk_MN{}); + } + }(); + tensor_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D)); + reference_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false); + tensor_SFD.sync_device(); + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + #if 0 + auto [M, N, K, L] = problem_shape_MNKL; + auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { + printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); + } + } + } + } + #endif + std::cout<<"D is incorrect"<(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + Arguments arguments = + { + {}, + tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d + }; + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = tensor_Aux.device_data(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + // Only initializing beta/beta_ptr for non-void source + if constexpr (not cute::is_void_v) { + fusion_args.beta = beta.at(coord_0); + fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr + } + + if constexpr (IsPerRowScaleEnabled) { + int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + fusion_args.dBeta = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + } + else if constexpr (IsPerColScaleEnabled) { + int32_t n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + fusion_args.dBeta = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + } + else { + if constexpr (not IsFfma2Kernel) { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + if (L > 1) { + fusion_args.dAlpha = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + fusion_args.dBeta = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + } + } + } + } + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + auto init_activation_args = [] (auto activation, auto& args) { + using Activation = cute::remove_cvref_t; + if constexpr (cute::is_same_v>) { + args.lower_bound = 0; // Treat Clamp as ReLU + args.upper_bound = cutlass::platform::identity_for_minimum(); + } + if constexpr (cute::is_same_v>) { + args.scale = ElementCompute(1); + } + }; + + if constexpr (not cute::is_same_v>) { + init_activation_args(ActivationFunctor{}, fusion_args.activation); + } + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + + if constexpr (IsBlockScaleSupported) { + arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(IsRowBiasEnabled ? M : N))); + auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); + auto Valpha = [&](){ + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto Vbeta = [&]() { + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(reference_SFD.host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. + return D; + } + }(); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor, + decltype(SfD), + Int, + cutlass::plus, + IsColBiasEnabled + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) + { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + epilogue_params.Valpha = Valpha; + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::initialize(problem_size, alpha, beta)"); +#endif + collective_mma_inputs.initialize(problem_size); + collective_epilogue.initialize(problem_size, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) + { + auto [M, N, K, L] = problem_shape_MNKL; + + bool passed = collective_mma_inputs.compare_reference(problem_shape_MNKL); + passed &= collective_epilogue.compare_reference(problem_shape_MNKL, alpha, beta); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file); + collective_epilogue.print_tensors(file); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); + auto epilogue_params = collective_epilogue.to_host_args(problem_size); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + bool passed = compare_reference(problem_shape_MNKL, alpha, beta); + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + bool profile( + ProblemShapeType problem_size, + int iterations, + Gemm& gemm_op, + typename Gemm::Arguments& arguments, + cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); + int L = 1; + if constexpr(cute::rank(ProblemShapeType{}) == 4) { + L = cute::size<3>(problem_size); + } + + + cutlass::Status status; + // + // Run the GEMM + // + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + status = gemm_op(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + return false; + } + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run"); +#endif + + // Fail test if insufficient CUDA device + if (!sufficient()) { + CUTLASS_TRACE_HOST("TestbedImpl::run: Test failed due to insufficient CUDA device"); + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: sufficient() returned true"); + } +#endif + + try { + const bool initialized = this->initialize(problem_size, alpha, beta); + if (not initialized) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize returned false"); + std::cerr << "Initialization failed \n"; + return false; + } + } + catch ([[maybe_unused]] std::exception const& e) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize() returned true"); +#endif + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = this->sm_count; + } + else { + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(); + + + if constexpr (IsRuntimeDataType) { + mainloop_args.runtime_data_type_a = runtime_input_datatype_a; + mainloop_args.runtime_data_type_b = runtime_input_datatype_b; + } + + + arguments = + { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + mainloop_args, + collective_epilogue.to_args(problem_size), + hw_info, + scheduler_args + }; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Creating gemm_op"); +#endif + Gemm gemm_op; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling Gemm::get_workspace_size"); +#endif + size_t workspace_size = Gemm::get_workspace_size(arguments); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size); +#endif + cutlass::device_memory::allocation workspace(workspace_size); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.can_implement"); +#endif + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + std::cerr << "This test is not supported: " << error_str << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling profile"); +#endif + return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); + } + else { + cudaError_t result; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.initialize"); +#endif + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.run"); +#endif + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling cudaDeviceSynchronize"); +#endif + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaDeviceSynchronize reports non-success"); + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling this->verify"); +#endif + bool passed = this->verify(problem_size, alpha, beta); + if (!passed) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify FAILED"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify passed"); + } +#endif + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Reached end"); +#endif + return passed; + } + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + , RuntimeDatatypeA + , RuntimeDatatypeB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + bool profiling = false, + detail::Iterations iterations = detail::Iterations{} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { + return impl_.run( + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode + , runtime_input_datatype_a, runtime_input_datatype_b + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestGemmPerf3x(int iterations = 20) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalar = ElementAccumulator; + bool passed = true; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector problem_size_m = { 4608 }; + std::vector problem_size_n = { 4608 }; + std::vector problem_size_k = { 8192 }; + + Testbed3x testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + RasterOrderOptions{}, detail::MaxSwizzleSize(1), detail::Splits{1}, DecompositionMode{}, + true, // profiling + detail::Iterations{iterations}); + + if (!passed) { + return false; + } + } + } + } + + return true; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +template < + typename Gemm, + typename RuntimeDataTypeA, + typename RuntimeDataTypeB, + bool force_legacy_epilogue = false> +bool TestRuntimeDataTypeSmall( + RuntimeDataTypeA runtime_input_datatype_a, + RuntimeDataTypeB runtime_input_datatype_b, + double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + + CtaShape_MNK cta_shape; + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment * MultiplierOffsetK, 512 + max_alignment * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + [[maybe_unused]] constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment; + + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + } + for (auto splits : problem_splits) { + + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + else { + std::cout << "Unsupported configuration for runtime datatype MXFP4." << std::endl; + return false; + } + } + + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + static_assert((cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v), + "Runtime datatype must be selected with an appropriate static umbrella data type."); + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // f6xf4 + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e3m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E3M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m3 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M3) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e5m2_e5m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e5m2 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e5m2_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for runtime datatype."); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment_k = 0; + int max_alignment_m = 0; + int max_alignment_n = 0; + + if constexpr (apply_alignment_offset) { + max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_n = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_m = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + } + // Alignment for SFD + if constexpr (detail::IsSfdEpi::value) { + using GmemLayoutTagScalefactor = typename Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagScalefactor; + constexpr int SFDVecSize = Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::SFVecSize; + if constexpr (cute::is_same_v) { + max_alignment_n = std::lcm(max_alignment_n, SFDVecSize); + } + else { + max_alignment_m = std::lcm(max_alignment_m, SFDVecSize); + } + } + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment_k * MultiplierOffsetK, 512 + max_alignment_k * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + std::vector raster_order_options = {RasterOrderOptions::Heuristic}; + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment_m; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment_n; + int l = test_batched_alpha_beta && wave == waves[0] && k == problem_size_k[0] ? 2 : 1; // only test the smallest problem size + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, l}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + for (RasterOrderOptions raster_order : raster_order_options) { + std::vector problem_splits = {detail::Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{4}); + } + } + for (auto splits : problem_splits) { + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, // raster_order + detail::MaxSwizzleSize(0), + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + EXPECT_TRUE(passed) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << l << " FAILED.\n"; + return false; + } + } // splits + } // raster_order + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + return TestSmall(alpha, + beta, + check_relative_equality, + use_device_scalars, + vector_scale_mode, + override_problem_size_k); +} + + + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); + + int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD}); + int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD}); + if constexpr (std::is_base_of_v) { + max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux); + max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux); + } + std::vector problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m}; + std::vector problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k}; + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + std::vector problem_splits = {detail::Splits{1}}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{3}); + + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + + // Use larger K sizes for stream-K tests + static constexpr int min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::min_iters_per_sk_unit_; + problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment_k}; + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; + std::vector max_swizzle_sizes{detail::MaxSwizzleSize{1}, detail::MaxSwizzleSize{4}}; + + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (auto raster_order : raster_orders) { + for (auto max_swizzle_size : max_swizzle_sizes) { + for (DecompositionMode decomp_mode : decomposition_modes) { + + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + auto max_splits = (k + TileShapeK - 1) / TileShapeK; + if (max_splits > 2) { + problem_splits.push_back(detail::Splits{2}); + } + if (max_splits > 3) { + problem_splits.push_back(detail::Splits{3}); + } + + problem_splits.push_back(detail::Splits{max_splits}); + + // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this + // case, split-K will fall back to a splitting factor of `max_splits`. + problem_splits.push_back(detail::Splits{max_splits + 1}); + } + for (auto splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, + max_swizzle_size, + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + + EXPECT_TRUE(passed) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // max_swizzle_size + } // raster_order + } // k + } // n + } // m + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment_m, 256 + max_alignment_n, 160 + max_alignment_k, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +template +bool TestAllBiasElementwise(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { + return TestAll(alpha, beta, check_relative_equality); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f18a7b39cbfe7dfb8d3251b2750e49261522de8a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -0,0 +1,1742 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed and host reference for EVT unittest +*/ + + +#pragma once +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +/// Host-side tapply, tapply in cute is HOST_DEVICE +template +constexpr auto +tapply(T&& t, F&& f, G&& g, cute::seq) +{ + return g(f(std::get(static_cast(t)))...); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT: Base class for EVT Node + +template < class ElementCompute_ > +class HostEVTNodeBase { +public: + using ElementCompute = ElementCompute_; + +private: + bool check_relative_equality_; + // Factors used for calculating relative equality. These default + // values are borrowed from those used by default in the CUTLASS + // profiler for performing relative equality checks. + float epsilon_ = 0.05f; + float nonzero_floor_ = 1.0f / 256.0f; + +public: + HostEVTNodeBase(){} + HostEVTNodeBase(bool check_relative_equality): + check_relative_equality_(check_relative_equality) { } + + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + if (check_relative_equality_) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, Element(epsilon_), Element(nonzero_floor_) + ); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + void* get_tensor_C_ptr() { + return nullptr; + } + + void* get_tensor_D_ptr() { + return nullptr; + } + + bool compare_reference(std::stringstream& error_ss) { + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Accumulator + +template< class ElementCompute = float > +class HostAccumulator: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + struct Arguments { }; + +public: + HostAccumulator(){} + template + HostAccumulator(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality) {} + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + cutlass::NumericConverter accumulator_converter; + return accumulator_converter(acc); + } + + Arguments get_arguments() { + return Arguments{}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Broadcast + +template < + int Value, + int BroadcastCount = 1, + class StrideMNL = cute::Stride, + template class ReductionFn = cutlass::multiplies, + class ElementCompute = float +> +class HostScalarBroadcast : public HostEVTNodeBase { +public: + + using Base = HostEVTNodeBase; + struct Arguments { + ElementCompute scalar[BroadcastCount] = {0}; + ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr }; + StrideMNL dScalar[BroadcastCount] = {}; + }; +private: + ElementCompute scalar_{}; + StrideMNL dScalar{}; + ElementCompute scalar_reduced_{}; +public: + HostScalarBroadcast(){} + + template + HostScalarBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality), scalar_(ElementCompute(Value)) { + scalar_ = ElementCompute(Value); + scalar_reduced_ = scalar_; + for (int i = 1; i < BroadcastCount; ++i) { + scalar_reduced_ = ReductionFn{}(scalar_reduced_, ElementCompute(Value)); + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return scalar_reduced_; + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss << "Scalar: " << float(scalar_) << "\n\n"; + return true; + } + + Arguments get_arguments() { + if constexpr (BroadcastCount == 1) + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + else if constexpr (BroadcastCount == 2) + return Arguments{{scalar_, scalar_}, {nullptr, nullptr}, {dScalar, dScalar}}; + else if constexpr (BroadcastCount == 3) + return Arguments{{scalar_, scalar_, scalar_}, {nullptr, nullptr, nullptr}, {dScalar, dScalar, dScalar}}; + else + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + } + + auto get_flatten_arguments() { + if constexpr (BroadcastCount == 1) { + return cute::make_tuple(scalar_, nullptr); + } + else if constexpr (BroadcastCount == 2) { + return cute::make_tuple(scalar_, scalar_, nullptr, nullptr); + } + else if constexpr (BroadcastCount == 3) { + return cute::make_tuple(scalar_, scalar_, scalar_, nullptr, nullptr, nullptr); + } + else { + return cute::make_tuple(scalar_, nullptr); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostRowBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int N_; +public: + HostRowBroadcast(){} + template + HostRowBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(N_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + + return bias_converter_(TensorBias(1, n + n_b)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerColumnBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostColBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int M_; +public: + HostColBroadcast(){} + template + HostColBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(M_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + + return bias_converter_(TensorBias(m + m_b, 1)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerRowBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Load + +template < + typename ElementAuxLoad_, + typename LayoutTagAux_, + bool isC = false, + typename ElementCompute = float +> +class HostAuxLoad: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementAuxLoad = ElementAuxLoad_; + using LayoutTagAux = LayoutTagAux_; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + ElementAuxLoad const *ptr_aux = nullptr; + ElementAuxLoad null_default = ElementAuxLoad(0); + StrideAux dAux = {}; + }; + + struct Arguments_C {}; + + using Arguments = cute::conditional_t; + +private: + cutlass::NumericConverter aux_load_converter_; + cutlass::HostTensor tensor_aux_load_; + + int M_, N_, L_; + + StrideAux stride_aux_; +public: + HostAuxLoad(){} + template + HostAuxLoad(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_NMKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_NMKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_load_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + EXPECT_TRUE( + detail::initialize_tensor( + tensor_aux_load_.host_view(), + cutlass::Distribution::Uniform, + seed + ) + ); + tensor_aux_load_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + + auto TensorAuxLoad = cute::make_tensor(tensor_aux_load_.host_data(), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + return aux_load_converter_(TensorAuxLoad(m + m_b, n + n_b, l)); + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (!isC) { + error_ss + << "AuxLoad = \n" << tensor_aux_load_.host_view()<< "\n\n"; + } + return true; + } + + void* get_tensor_C_ptr() { + if constexpr (isC) { + return static_cast(tensor_aux_load_.device_data()); + } + else { + return nullptr; + } + } + + Arguments get_arguments() { + if constexpr (isC) + return {}; + else + return {tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_}; + } + + auto get_flatten_arguments() { + if constexpr (isC) + return cute::make_tuple(); + else + return cute::make_tuple(tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Compute + +template +T* findNonNullPtr(T* first_ptr) { + return first_ptr; +} + +template +T* findNonNullPtr(T* first_ptr, Args... args) { + if (first_ptr) { + return first_ptr; + } + return findNonNullPtr(args...); +} + +template < + template class ComputeOp_, + typename ElementCompute = float +> +class HostCompute: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ComputeOp = ComputeOp_; + + struct Arguments { + struct OpArgs {} op; + }; +private: + ComputeOp op_; +public: + HostCompute(){} + template + HostCompute(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, Args... frg_inputs) { + return op_(frg_inputs...); + } + + Arguments get_arguments(){ + return {}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Store + +template < + class ElementAuxStore_, + typename LayoutTagAux_, + bool isD = false, + bool isRelu = false, + typename ElementCompute = float +> +class HostAuxStore: public HostEVTNodeBase { +public: + using ElementAuxStore = ElementAuxStore_; + using LayoutTagAux = LayoutTagAux_; + + using Base = HostEVTNodeBase; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + struct OpArgs { + ElementAuxStore* ptr_aux = nullptr; + StrideAux dAux = {}; + } op; + }; + + struct Arguments_D {}; + + using Arguments = cute::conditional_t; + + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_aux_store_; + cutlass::HostTensor reference_aux_store_; + int M_, N_, L_; + StrideAux stride_aux_; +public: + HostAuxStore(){} + template + HostAuxStore(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_MNKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + + reference_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + tensor_aux_store_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + + auto TensorAuxStore = cute::make_tensor(detail::make_iterator(static_cast(reference_aux_store_.host_data())), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + if constexpr (isRelu) + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result >= 0); + else + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_aux_store_.sync_host(); + + bool equal = this->equality_check(reference_aux_store_.host_view(), tensor_aux_store_.host_view()); + if (!equal) { + error_ss + << "\n\nReference =\n" << reference_aux_store_.host_view() + << "\n\nComputed =\n" << tensor_aux_store_.host_view() << "\n\n"; + } + return equal; + } + + void* get_tensor_D_ptr() { + if constexpr (isD) + return static_cast(tensor_aux_store_.device_data()); + else + return nullptr; + } + + Arguments get_arguments() { + if constexpr (isD) { + return {}; + } + else { + return {tensor_aux_store_.device_data(), stride_aux_}; + } + } + + auto get_flatten_arguments() { + if constexpr (isD) { + return cute::make_tuple(); + } + else { + return cute::make_tuple(tensor_aux_store_.device_data(), stride_aux_); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostRowReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_row = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_row_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_row_reduce_; + int N_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostRowReduce(){} + template + HostRowReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + if constexpr (FinalReduction) { + tensor_row_reduce_.resize(cutlass::Coord<1>(N_)); + reference_row_reduce_.resize(cutlass::Coord<1>(N_)); + reduce_buffer_.resize(cutlass::Coord<1>(N_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile); + extent_n_ = cute::get<1>(NumTile) * TileN; + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_row_reduce_.resize(shape); + reference_row_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + if constexpr (FinalReduction) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + TensorRowReduce(1, n + n_b) = reduce_fn_(TensorRowReduce(1, n + n_b), child_0_result); + } + else { + auto TensorRowReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(extent_n_, 1, extent_m_ * extent_l_) + ) + ); + TensorRowReduce((m+m_b)/TileM, n+n_b, l) = reduce_fn_(TensorRowReduce((m+m_b)/TileM, n+n_b, l), child_0_result); + } + + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_row_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_row_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_row_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t n = 0; n < size(TensorRowReduce); n ++) { + TensorRowReduce(n) = destination_converter_(TensorReduceBuffer(n)); + } + + bool equal = this->equality_check(reference_row_reduce_.host_view(), tensor_row_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nRow Reduce Reference =\n" << reference_row_reduce_.host_view() + << "\n\nRow Reduce Computed =\n" << tensor_row_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_row_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostColumnReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_col = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_column_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_column_reduce_; + int M_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostColumnReduce(){} + template + HostColumnReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + + if constexpr (FinalReduction) { + tensor_column_reduce_.resize(cutlass::Coord<1>(M_)); + reference_column_reduce_.resize(cutlass::Coord<1>(M_)); + reduce_buffer_.resize(cutlass::Coord<1>(M_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile) * TileM; + extent_n_ = cute::get<1>(NumTile); + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_column_reduce_.resize(shape); + reference_column_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorColReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + if constexpr (FinalReduction) { + TensorColReduce(m + m_b, 1) = reduce_fn_(TensorColReduce(m + m_b, 1), child_0_result); + } + else { + auto shape = reduce_buffer_.extent(); + auto TensorColReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(1, extent_m_, extent_m_ * extent_l_) + ) + ); + TensorColReduce(m+m_b, (n+n_b)/TileN, l) = reduce_fn_(TensorColReduce(m+m_b, (n+n_b)/TileN, l), child_0_result); + } + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_column_reduce_.sync_host(); + + auto TensorColReduce = cute::make_tensor(reference_column_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_column_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t m = 0; m < size(TensorColReduce); m ++) { + TensorColReduce(m) = destination_converter_(TensorReduceBuffer(m)); + } + + bool equal = this->equality_check(reference_column_reduce_.host_view(), tensor_column_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nColumn Reduce Reference =\n" << reference_column_reduce_.host_view() + << "\n\nColumn Reduce Computed =\n" << tensor_column_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_column_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Reduce + +template < + template class ReduceFn, + typename ElementReduce, + typename ElementCompute = float, + bool enabled = true +> +class HostScalarReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_scalar = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dScalar = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_scalar_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_scalar_reduce_; + ReduceFn reduce_fn_; +public: + HostScalarReduce(){} + template + HostScalarReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + tensor_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reference_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reduce_buffer_.resize(cutlass::Coord<1>(1)); + + tensor_scalar_reduce_.sync_device(); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + TensorRowReduce(0) = reduce_fn_(TensorRowReduce(0), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (enabled) { + // Verify the store node + tensor_scalar_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_scalar_reduce_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + // Filling the reference tensor with the reduce buffer + TensorRowReduce(0) = destination_converter_(TensorReduceBuffer(0)); + + bool equal = this->equality_check(reference_scalar_reduce_.host_view(), tensor_scalar_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nScalar Reduce Reference =\n" << reference_scalar_reduce_.host_view() + << "\n\nScalar Reduce Computed =\n" << tensor_scalar_reduce_.host_view() << "\n\n"; + } + return equal; + } + else { + return true; + } + + } + + Arguments get_arguments() { + return {tensor_scalar_reduce_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(tensor_scalar_reduce_.device_data()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Host EVT wrapper + +/// The ArgumentPack is used to model the alignment when num ops <= 4 +template +struct ArgumentPack; + +template +struct ArgumentPack { + T arg; + ArgumentPack(T first): + arg(first) {} +}; + +template +struct ArgumentPack { + First arg; + ArgumentPack rest_args; + + ArgumentPack(First first, Rest... rest) : + arg(first), rest_args(rest...) {} +}; + + +/// Base class for Host Visitor +template +struct HostVisitorBase: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + using Arguments_struct = ArgumentPack; + using Arguments_tuple = cute::tuple; + + constexpr static int Rm1 = sizeof...(Ops); + constexpr static bool cond = Rm1 > 4; + using Arguments = cute::conditional_t; + + std::tuple ops; + + HostVisitorBase(){} + template + HostVisitorBase(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality), + ops(test::gemm::device::tapply(std::tuple{}, + [&] (auto&& op) { + using Op = cute::remove_cvref_t; + return Op(problem_size, check_relative_equality, seed); + }, + [] (auto&&... _ops) { + return std::make_tuple(_ops...); + }, + cute::make_seq{} + )){ } + + bool compare_reference(std::stringstream& error_ss) { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.compare_reference(error_ss); + }, + [&] (auto&&... inputs) { + return arrayAnd(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_C_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_C_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_D_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_D_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + Arguments get_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_arguments(); + }, + [&] (auto&&... args) { + if constexpr (Rm1 > 4) { + return cute::make_tuple(args...); + } + else { + return Arguments(args...); + } + }, + cute::make_seq{} + ); + } + + auto get_flatten_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_flatten_arguments(); + }, + [&] (auto&&... args) { + return flatten(cute::make_tuple(args...)); + }, + cute::make_seq{} + ); + } + + bool arrayAnd(bool passed) { + return passed; + } + + template + bool arrayAnd(bool first_passed, Args... passed) { + if (first_passed) { + return arrayAnd(passed...); + } + return first_passed; + } + +}; + + +/// Tree-struct visitor +template +struct HostTreeVisitor: public HostVisitorBase { +public: + using ElementCompute = typename NodeOp::Base::ElementCompute; + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm1 = sizeof...(ChildOps); + + HostTreeVisitor(){} + template + HostTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed){ } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + return cute::detail::tapply(this->ops, + [&] (auto& op) { + return op.visit(m, n, l, m_b, n_b, acc); + }, + [&] (auto&&... frg_inputs) { + return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + }, + cute::make_seq{} + ); + } +}; + + +/// General Graph visitor +template +struct HostTopoVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + constexpr static int Rm1 = Base::Rm1; + using Arguments = typename Base::Arguments; + +private: + ElementCompute frg_outputs_[Rm1]; +public: + HostTopoVisitor(){} + template + HostTopoVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + ElementCompute visit_( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + frg_outputs_[I] = cute::transform_apply(cute::get(EdgeTuple{}), + [&] (auto&& _E) { + constexpr int e = cute::remove_cvref_t::value; + return frg_outputs_[e]; + }, + [&] (auto const&... frg_inputs) { + ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + return res; + } + ); + + if constexpr (I < Rm1 - 1) { + return visit_(m, n, l, m_b, n_b, acc); + } + else { + return frg_outputs_[I]; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return visit_(m, n, l, m_b, n_b, acc); + } + +}; + + +/// SplitTree visitor +template +struct HostSplitTreeVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm2 = sizeof...(AuxOutTrees); + +private: + ElementCompute frg_input_; +public: + HostSplitTreeVisitor(){} + template + HostSplitTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + void visitAux( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator frag) { + std::get(this->ops).visit(m, n, l, m_b, n_b, frag); + + if constexpr (I < Rm2 - 1) { + return visitAux(m, n, l, m_b, n_b, frag); + } + else { + return; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + /// Compute the input tree + frg_input_ = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); + + /// Compute the aux out tree + visitAux(m, n, l, m_b, n_b, frg_input_); + /// Visit the output tree + return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input_); + } +}; + +/// Universal testbed for EVT w/o smem +template +class Testbed3xEVTnoSmem { +public: + // The EVT Module to test + using EVTModule = EVT; //typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // + // Methods + // + Testbed3xEVTnoSmem( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVTnoSmem( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view(); + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + int iterations = 20, + bool profiling = false) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + {}, + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + if constexpr (FlatArgs) { + auto epilogue_args = host_reference.get_flatten_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args, sizeof(epilogue_args)); + + arguments.epilogue.ptr_C = static_cast(host_reference.get_tensor_C_ptr()); + arguments.epilogue.dC = impl_.collective_epilogue.stride_c; + + arguments.epilogue.ptr_D = static_cast(host_reference.get_tensor_D_ptr()); + arguments.epilogue.dD = impl_.collective_epilogue.stride_d; + } + else { + auto epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue, &epilogue_args, sizeof(epilogue_args)); + } + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +/// Universal testbed for EVT +template +class Testbed3xEVT { +public: + // The EVT Module to test + using EVTModule = typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + + // + // Methods + // + Testbed3xEVT( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVT( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + Testbed3xEVT( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + bool profiling = false, + int iterations = 20, + int splits = 1) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { splits }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + { // Epilogue arguments + {}, // thread + static_cast(host_reference.get_tensor_C_ptr()), + impl_.collective_epilogue.stride_c, + static_cast(host_reference.get_tensor_D_ptr()), + impl_.collective_epilogue.stride_d + }, // Epilogue arguments end + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +template +bool TestAllEVT(bool check_relative_equality = false) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xEVT testbed(check_relative_equality); + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run(problem_size); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbc54ec582d88d9039968d8153cf6127a06ec274 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -0,0 +1,2409 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed for Ptr-Array and Grouped GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail{ + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + std::vector stride_a_host; + std::vector stride_b_host; + + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + // for pointer array problem_shapes.groups() is 1 + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for(int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + Arguments arguments; + + if constexpr (IsGroupGemm) { + arguments + = + { + device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() + }; + } + else { + arguments = + { + device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + + bool passed = true; + return passed; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using InternalLayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using InternalLayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + std::vector stride_a_host; + std::vector stride_b_host; + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + std::vector layout_sfa_host; + std::vector layout_sfb_host; + cutlass::DeviceAllocation layout_sfa_device; + cutlass::DeviceAllocation layout_sfb_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + std::vector> tensors_SFA; + std::vector> tensors_SFB; + + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + cutlass::DeviceAllocation device_tensors_SFA; + cutlass::DeviceAllocation device_tensors_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + tensors_SFA.clear(); + tensors_SFB.clear(); + layout_sfa_host.clear(); + layout_sfb_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + + using namespace cute; + + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); + layout_sfb_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); + + tensors_SFA.push_back(cutlass::HostTensor(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A))); + tensors_SFB.push_back(cutlass::HostTensor(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_SFA[i].host_view(), init_A, seed + 2024 + i)); + EXPECT_TRUE(initialize_tensor(tensors_SFB[i].host_view(), init_B, seed + 2025 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_SFA[i].host_view().at({0, 0}) = ElementSF(1); + tensors_SFB[i].host_view().at({0, 0}) = ElementSF(1); + + tensors_SFA[i].sync_device(); + tensors_SFB[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + std::vector ptr_SFA_host(L); + std::vector ptr_SFB_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + ptr_SFA_host.at(i) = tensors_SFA[i].device_data(); + ptr_SFB_host.at(i) = tensors_SFB[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + device_tensors_SFA.reset(L); + device_tensors_SFA.copy_from_host(ptr_SFA_host.data()); + + device_tensors_SFB.reset(L); + device_tensors_SFB.copy_from_host(ptr_SFB_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + layout_sfa_device.reset(problem_shapes.groups()); + layout_sfa_device.copy_from_host(layout_sfa_host.data()); + + layout_sfb_device.reset(problem_shapes.groups()); + layout_sfb_device.copy_from_host(layout_sfb_host.data()); + + if constexpr (IsGroupGemm) { + return Arguments{ + device_tensors_A.get(), stride_a_device.get(), + device_tensors_B.get(), stride_b_device.get(), + device_tensors_SFA.get(), layout_sfa_device.get(), + device_tensors_SFB.get(), layout_sfb_device.get() + }; + } + else { + return Arguments{ + device_tensors_A.get(), stride_a_host[0], + device_tensors_B.get(), stride_b_host[0], + device_tensors_SFA.get(), layout_sfa_host[0], + device_tensors_SFB.get(), layout_sfb_host[0] + }; + } + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto SfA = make_tensor(tensors_SFA[batch].host_data(), layout_sfa_host[batch]); + + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + auto SfB = make_tensor(tensors_SFB[batch].host_data(), layout_sfb_host[batch]); + + return cutlass::reference::host::GettMainloopParams + {A, SfA, B, SfB}; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view() + << "\nSFA =\n" << tensors_SFA[batch].host_view() + << "\nSFB =\n" << tensors_SFB[batch].host_view(); + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFA[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFB[batch].host_view()), 0); + return true; + } +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + ElementScalar alpha; + ElementScalar beta; + + std::vector> tensors_C; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_C; + cutlass::DeviceAllocation device_tensors_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M, N); + + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + L = std::max(problem_shapes.groups(), L); + + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + + using FusionOp = typename Gemm::EpilogueOutputOp; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + using ElementSFD = non_void_t, ElementD>; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + SFD_VectorSize + >; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + std::vector> tensors_SFD; + std::vector> references_SFD; + cutlass::DeviceAllocation device_tensors_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + std::vector> tensors_C; + cutlass::DeviceAllocation device_tensors_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + std::vector> tensors_Aux; + cutlass::DeviceAllocation device_tensors_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_D; + + // References + cutlass::HostTensor reference_dbias; + std::vector> references_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + tensors_SFD.clear(); + references_SFD.clear(); + + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + auto c_coord = cutlass::make_Coord(M, N); + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + if constexpr (IsPerRowScaleEnabled) { + alpha.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (vector_scale_mode == VectorScale::DISABLED) { + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + else { + beta.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + } + else { + alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsBiasEnabled) { + bias.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + tensors_Aux.clear(); + references_Aux.clear(); + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + for (int32_t i = 0; i < L; ++i) { + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); + tensors_Aux[i].sync_device(); + } + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + } + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); + + if constexpr (IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); + tensors_Aux[i].sync_device(); + } + + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + // If block scaled output is supported we always have at least 1 SFD + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + return cutlass::make_Coord(m_blks * Blk_MN{}, n_blks * Blk_SF{}); + }(); + tensors_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D))); + references_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false)); + tensors_SFD[i].sync_device(); + } + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + std::vector ptr_Aux_host(L); + if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); + } + device_tensors_Aux.reset(L); + device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); + } + + auto device_tensors_C_ptr = cute::is_void_v ? nullptr : + reinterpret_cast(device_tensors_C.get()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = device_tensors_Aux.get(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + + fusion_args.alpha_ptr = alpha.device_data(); + // can_implement requires beta_ptr to not be set if its voidC + fusion_args.beta_ptr = cute::is_void_v ? nullptr : + beta.device_data(); + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { + fusion_args.activation.scale = ElementCompute(1); + } + + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); + } + + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + if constexpr (IsBlockScaleSupported) { + std::vector ptr_SFD_host(L); + for (int32_t i = 0; i < L; ++i) { + ptr_SFD_host.at(i) = tensors_SFD[i].device_data(); + } + device_tensors_SFD.reset(L); + device_tensors_SFD.copy_from_host(ptr_SFD_host.data()); + + arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); + auto Aux = [&]() { + auto ptr = recast_ptr(nullptr); + if (IsAuxInEnabled) { + ptr = detail::make_iterator(tensors_Aux[batch].host_data()); + } else if (IsAuxOutEnabled) { + ptr = detail::make_iterator(references_Aux[batch].host_data()); + } + return cute::make_tensor(ptr, Aux_layout); + }(); + auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. + return D; + } + }(); + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor + , decltype(SfD) + , Int + , cutlass::plus + , false + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsBiasEnabled or IsDeBiasEnabled) { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + collective_mma_inputs.initialize(problem_shapes); + collective_epilogue.initialize(problem_shapes, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) + { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + + bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); + passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << batch << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file, batch); + collective_epilogue.print_tensors(file, batch); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + bool passed = true; + for (int32_t i = 0; i < L; ++i) { + auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); + auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + passed &= compare_reference(problem_shapes, alpha, beta, i); + } + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + + // Fail test if insufficient CUDA device + if (!sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + + if (!this->initialize(problem_shapes, alpha, beta)) { + std::cerr << "Initialization failed \n"; + return false; + } + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(problem_shapes); + + if constexpr (IsGroupGemm) { + arguments = + { + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + else { + arguments = + { + cutlass::gemm::GemmUniversalMode::kArray, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return false; + } + + // + // Run the GEMM + // + + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_shapes, alpha, beta); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } + + return passed; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + return impl_.run( + problem_shapes, alpha, beta, iterations); + } +}; + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorScale::DISABLED); + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + int batches[] = {5, 10}; + + bool passed = true; + + for (int batch : batches) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); + } + + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + passed = testbed.run( + ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()}, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_size{{m, n, k, batch}}; + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n"; + return false; + } + } // k + } // n + } // m + } // batch + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSmall(double alpha = 1.0, double beta = 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ElementA = typename Gemm::GemmKernel::ElementA; + using ElementB = typename Gemm::GemmKernel::ElementB; + using TiledMma = typename Gemm::GemmKernel::TiledMma; + + static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); + // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); + int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); + + int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); + int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); + + int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b); + + if constexpr (apply_alignment_offset) { + // If BlockScaled, then min alignment is SFVecSize + static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + if constexpr (IsBlockScaleSupported) { + alignment_input = cutlass::round_up(alignment_input, SFVecSize); + } + } + + + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime + static constexpr int SmCount = 16; + + float waves[] = {0.5, 2.5}; + int batches[] = {3}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + // this is to test with min alignment + problem_size_k = {256 - alignment_input, 512 + alignment_input}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + bool passed = true; + + for (int batch : batches) { + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + float num_grid = wave * SmCount; + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = static_cast(num_grid) / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = static_cast(num_grid) / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) - alignment_input; // this is just to test with unusual problem shapes + int n = grid_n * cute::size<1>(cta_shape) + alignment_input; + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 2) + 1), n * ((i % 3) + 1), k * ((i % 2) + 1)}); + } + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + ProblemShapeType problem_shapes{batch, problem_sizes_device.get(), problem_sizes_host.data()}; + + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + for (int i = 0; i < batch; ++i) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape(i) << " \n"; + } + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_shapes{{m, n, k, batch}}; + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape() << " \n"; + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // k + } // waves + } // batches + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = 0.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED) { + return TestSmall( + alpha, beta, check_relative_equality, use_device_scalars, vector_scale_mode); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b00f98a97846de175f1c6f95919c483ab4b81da --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -0,0 +1,515 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with elementwise tensor-tensor broadcast epilogue +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed_utils.h" +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed3xTensorBroadcast { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementA = typename Kernel::ElementA; + using StrideA = typename Kernel::StrideA; + using ElementB = typename Kernel::ElementB; + using StrideB = typename Kernel::StrideB; + using ElementC = typename Kernel::ElementC; + using StrideC = typename Kernel::StrideC; + using ElementD = typename Kernel::ElementD; + using StrideD = typename Kernel::StrideD; + + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ElementScalar = typename Epilogue::ElementScalar; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementBias = typename Epilogue::ElementBias; + using ActivationFunctor = typename Epilogue::ActivationFunctor; + + static constexpr bool IsBinaryOp0Enabled = Epilogue::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled; + + static constexpr bool PerColBias = Epilogue::PerColumnBias; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C1; + // tensor_C0 is taken from TestbedImpl's tensor_C + + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3xTensorBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_) { } + + Testbed3xTensorBroadcast( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_HOST, VectorScale::ENABLED, + init_A_, + init_B_, + init_C_, + cutlass::Distribution::Uniform, + cutlass::Distribution::Uniform, + seed_) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B/C/D tensor + // + impl_.initialize(problem_size); + } + + void initialize_bias(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL); + bias.resize(cutlass::Coord<1>(bias_size)); + + EXPECT_TRUE(detail::initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023)); + bias.sync_device(); + } + + void initialize_c1(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto c_coord = cutlass::make_Coord(M * L, N); + + tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C)); + EXPECT_TRUE(detail::initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2024)); + tensor_C1.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta, + bool use_bias) + { + auto [M, N, K, L] = problem_shape_MNKL; + + impl_.collective_epilogue.tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_B.host_view()), 0); + + if (impl_.collective_epilogue.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.tensor_D.host_view()), 0); + } + + if (impl_.collective_epilogue.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.reference_D.host_view()), 0); + } + + bool passed = cutlass::reference::host::TensorEquals(impl_.collective_epilogue.reference_D.host_view(), impl_.collective_epilogue.tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_broadcast" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias + << ", per-col bias: " << PerColBias << "\n\n"; + + if (use_bias){ + file << "Bias = \n" << bias.host_view()<< "\n\n"; + } + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC0 =\n" << impl_.collective_epilogue.tensor_C.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << impl_.collective_epilogue.reference_D.host_view() + << "\n\nComputed =\n" <(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto D = cute::make_tensor(impl_.collective_epilogue.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), + cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1))); + auto C0 = cute::make_tensor(impl_.collective_epilogue.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + auto C1 = cute::make_tensor(tensor_C1.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + // Create host workspace for output of testbed. This computes a portion of the epilogue: + // ref_compute_out = Activation(alpha * (A @ B) + bias) + cutlass::HostTensor ref_compute_out; + auto c_coord = cutlass::make_Coord(M * L, N); + ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C), false); + auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + // Use a dummy null tensor for operand C because the epilogue overrides C. + auto dummy_C = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + ElementCompute dummy_beta(0); + auto dummy_Aux = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + + auto dummy_SFD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + using DummySFDVectorSize = cute::Int<0>; + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(dummy_C), + decltype(RefComputeOut), + decltype(Bias), + decltype(dummy_Aux), + decltype(dummy_Valpha), + decltype(dummy_Vbeta), + ActivationFunctor, + decltype(dummy_SFD), + DummySFDVectorSize, + cutlass::plus, + PerColBias> epilogue_params{ + alpha, + dummy_beta, + dummy_C, + RefComputeOut, + Bias, + dummy_Aux, + dummy_Valpha, + dummy_Vbeta + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + cutlass::NumericConverter source_converter; + cutlass::NumericConverter destination_converter; + cutlass::multiplies mul; + + // Compute broadcast operations atop the reference + #pragma omp parallel for collapse(3) + for (int64_t l = 0; l < cute::size<2>(A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(A.layout()); ++m) { + for (int64_t n = 0; n < cute::size<0>(B.layout()); ++n) { + ElementCompute intermediate = RefComputeOut(m, n, l); + // Apply BinaryOp0, if needed + if constexpr (IsBinaryOp0Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp0 bin0; + ElementCompute converted_source = source_converter(C0(m, n, l)); + intermediate = bin0(intermediate, mul(beta, converted_source)); + } + + // Apply BinaryOp1, if needed + if constexpr (IsBinaryOp1Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp1 bin1; + ElementCompute converted_source = source_converter(C1(m, n, l)); + intermediate = bin1(intermediate, mul(beta, converted_source)); + } + + // Apply UnaryOp, if needed + if constexpr (IsUnaryOpEnabled) { + typename Epilogue::ThreadEpilogueOp::UnaryOp unary; + intermediate = unary(intermediate); + } + + D(m, n, l) = destination_converter(intermediate); + } + } + } + + return compare_reference(problem_shape_MNKL, alpha, beta, use_bias); + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20, + bool use_bias = true) + { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + /// Initializes data structures + /// A/B/C0/D Tensor + initialize(problem_size); + initialize_bias(problem_size); + + if constexpr (IsBinaryOp1Enabled) { + initialize_c1(problem_size); + } + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b, + impl_.mma_promotion_interval + }, + { // Epilogue arguments + { alpha, beta }, // ThreadOp arguments + impl_.collective_epilogue.stride_c, + impl_.collective_epilogue.tensor_D.device_data(), + impl_.collective_epilogue.stride_d, + use_bias ? bias.device_data() : nullptr, + impl_.collective_epilogue.tensor_C.device_data(), + tensor_C1.device_data() + }, // Epilogue arguments end + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha, beta, use_bias); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) + << ", beta: " << float(beta) + << ", use_bias: " << use_bias + << "\n"; + } + + return passed; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllTensorBroadcast(bool use_bias=true) { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xTensorBroadcast testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (bool use_bias : {true, false}) { + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20, // iterations + use_bias + ); + + if (!passed) { + return false; + } + } + } + } + } + + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20 // iterations + ); + if (!passed) { + return false; + } + } + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6ae7b864cb272782da4920ffc038830d3b5984b2 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = + typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope = (cutlass::sizeof_bits::value == 8) ? 2 : 8; + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, + -scope, 0); + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), + view.capacity()); + } else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waives test if CUDA device is insufficient + if (!sufficient()) { + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor + tensor_A(problem_size.mk()); + + cutlass::HostTensor + tensor_B(problem_size.kn()); + + cutlass::HostTensor + tensor_C(problem_size.mn()); + + cutlass::HostTensor + tensor_D(problem_size.mn()); + + cutlass::HostTensor + reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, tensor_A.device_ref(), tensor_B.device_ref(), + tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, + reference_D.host_ref(), ElementAccumulator(0)); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + + fname << "error_Gemm_device_" << problem_size.m() << "x" + << problem_size.n() << "x" << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN + << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM + << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK + << ".txt"; + + std::ofstream file(fname.str()); + + file << "problem: " << problem_size << ", alpha: " << alpha + << ", beta: " << beta << "\n\n"; + + file << "A =\n" + << tensor_A.host_view() << "\nB =\n" + << tensor_B.host_view() << "\nC =\n" + << tensor_C.host_view() << "\n\nReference =\n" + << reference_D.host_view() << "\nComputed =\n" + << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = {16, 528}; + + int problem_size_n[] = {16, 528}; + + int problem_size_k[] = {Gemm::InstructionShape::kK, + Gemm::ThreadblockShape::kK * Gemm::kStages + + Gemm::InstructionShape::kK}; + + double problem_alpha[] = {1.0}; + + // TODO Try non zero beta value after multistaged epilogue is implemented + double problem_beta[] = {0.0}; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + passed = + run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..e309208bb4311253be5b7366841164eb62748bab --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageInterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageInterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 0.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py new file mode 100644 index 0000000000000000000000000000000000000000..a180028205abb689436c73403eea82758ade7da9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py @@ -0,0 +1,341 @@ +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# this file creates the test/unit/gemm/device simt tests + + +outputDir = "" + +################################################################################ +# parameters +# Edge - for tiles, the edges represent the length of one side +# Ratio - the maximum ratio between 2 edges, limits the skinnyness of tiles +# MaxEdge - maximum length of each edge +# Min/Max - minimum/maximum of the product of edge lengths +################################################################################ + +warpsPerThreadblockEdge = [1, 2, 4, 8, 16] +warpsPerThreadblockRatio = 2 +warpsPerThreadblockMax = 16 +# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases + +warpShapeEdges = [8, 16, 32, 64, 128, 256] +warpShapeRatio = 4 +warpShapeMax = 64*64 +warpShapeMin = 8*8 + +threadblockEdgeMax = 256 + +# char, type bits/elem, max tile, L0 threadblock tiles +precisions = [ + ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["q", "cutlass::Quaternion", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], + ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], + ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], + ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], + ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], + ] +# L1 will have a single kernel for every unique shape +# L2 will have everything else + +transposes = [ + [False, False], + [False, True], + [True, False], + [True, True] + ] + +################################################################################ +# warps per threadblock +################################################################################ +warpsPerThreadblocks = [] +for warpsPerThreadblock0 in warpsPerThreadblockEdge: + for warpsPerThreadblock1 in warpsPerThreadblockEdge: + if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax: + warpsPerThreadblocks.append([warpsPerThreadblock0, + warpsPerThreadblock1]) +print("WarpsPerThreadblocks",warpsPerThreadblocks) + +################################################################################ +# warp shapes +################################################################################ +warpNumThreads = 32 +warpShapes = [] +for warp0 in warpShapeEdges: + for warp1 in warpShapeEdges: + if warp0 / warp1 <= warpShapeRatio and warp1 / warp0 <= warpShapeRatio and warp0*warp1 <= warpShapeMax and warp0*warp1 > warpShapeMin: + warpShapes.append([warp0, warp1]) +print("WarpShapes", warpShapes) + +numL0 = 0 +numL1 = 0 +numL2 = 0 + +################################################################################ +# create kernels +# create a file for each precision/transpose +# each file contains many tile sizes +################################################################################ + +# precisions +for precision in precisions: + + # get precision char + precisionChar = precision[0] + precisionType = precision[1] + precisionBits = precision[2] + threadblockMaxElements = precision[3] + threadblockTilesL0 = precision[4] + + # transposes + for transpose in transposes: + + # get transpose char + columnMajorA = transpose[0] + columnMajorB = transpose[1] + transCharA = "n" if columnMajorA else "t" + transCharB = "n" if columnMajorB else "t" + + # open file + fileName="simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB) + print("\n", fileName) + filePath = "%s%s" % (outputDir, fileName) + out = open(filePath, "w+") + + # write file header + out.write("/***************************************************************************************************\n" +" * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n" +" * SPDX-License-Identifier: BSD-3-Clause \n" +" * \n" +" * Redistribution and use in source and binary forms, with or without \n" +" * modification, are permitted provided that the following conditions are met: \n" +" * \n" +" * 1. Redistributions of source code must retain the above copyright notice, this \n" +" * list of conditions and the following disclaimer. \n" +" * \n" +" * 2. Redistributions in binary form must reproduce the above copyright notice, \n" +" * this list of conditions and the following disclaimer in the documentation \n" +" * and/or other materials provided with the distribution. \n" +" * \n" +" * 3. Neither the name of the copyright holder nor the names of its \n" +" * contributors may be used to endorse or promote products derived from \n" +" * this software without specific prior written permission. \n" +" * \n" +" * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" \n" +" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE \n" +" * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE \n" +" * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE \n" +" * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL \n" +" * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR \n" +" * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER \n" +" * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, \n" +" * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE \n" +" * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \n" +" *\n" +" **************************************************************************************************/\n" +"/*! \\file\n" +" \\brief Tests for device-wide GEMM interface\n" +"*/\n" +"\n" +"#include \n" +"\n" +"#include \"cutlass/cutlass.h\"\n" +"#include \"cutlass/gemm/device/gemm.h\"\n" +"#include \"cutlass/numeric_types.h\"\n" +"\n" +"#include \"../../common/cutlass_unit_test.h\"\n" +"\n" +"#include \"cutlass/util/host_tensor.h\"\n" +"#include \"cutlass/util/tensor_view_io.h\"\n" +"#include \"cutlass/util/reference/host/tensor_fill.h\"\n" +"#include \"cutlass/util/reference/host/tensor_copy.h\"\n" +"#include \"cutlass/util/reference/host/tensor_compare.h\"\n" +"#include \"cutlass/util/reference/host/gemm.h\"\n" +"\n" +"#include \"testbed.h\"\n" +"\n") + foundThreadblockTilesL0 = {} + foundThreadblockTilesL1 = {} + + ######################################################################## + # for each combination of tile sizes + ######################################################################## + for warpsPerThreadblock in warpsPerThreadblocks: + for warpShape in warpShapes: + warpThreadsM = 0 + if warpShape[0] > warpShape[1]: + warpThreadsM = 8 + else: + warpThreadsM = 4 + warpThreadsN = warpNumThreads / warpThreadsM + + # skip shapes with conflicting rectangularity + # they are unlikely to be fastest + blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] + blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] + warpG = warpShape[0] > warpShape[1] + warpL = warpShape[0] < warpShape[1] + + blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2 + blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1] + warpG2 = warpShape[0] > warpShape[1]*2 + warpL2 = warpShape[0]*2 < warpShape[1] + + if blockG2 and warpL: continue + if blockL2 and warpG: continue + if warpG2 and blockL: continue + if warpL2 and blockG: continue + + # check threadblock ratios and max + threadblockTile = [warpShape[0]*warpsPerThreadblock[0], + warpShape[1]*warpsPerThreadblock[1]] + if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue + if threadblockTile[0] > threadblockEdgeMax: continue + if threadblockTile[1] > threadblockEdgeMax: continue + totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1] + + # calculate unroll + # ensure that every iteration at least a full load of A,B are done + unrollMin = 8 + unrollMin0 = totalThreads / threadblockTile[0] + unrollMin1 = totalThreads / threadblockTile[1] + unroll = max(unrollMin, unrollMin0, unrollMin1) + + threadTileM = warpShape[0] / warpThreadsM + threadTileN = warpShape[1] / warpThreadsN + if threadTileM < 2 or threadTileN < 2: continue + if threadTileM*threadTileN*precisionBits > 8*8*32: continue + + # epilogue currently only supports N < WarpNumThreads + if threadblockTile[1] < warpNumThreads: continue + + # limit smem + smemBitsA = threadblockTile[0]*unroll*2*precisionBits + smemBitsB = threadblockTile[1]*unroll*2*precisionBits + smemKBytes = (smemBitsA+smemBitsB)/8/1024 + if (smemKBytes > 48): continue + + # test level 0 + testLevel = -1 + for tileId in range(0, len(threadblockTilesL0)): + tbTile = threadblockTilesL0[tileId] + if tbTile[0] == threadblockTile[0] and tbTile[1] == threadblockTile[1]: + if tuple(tbTile) not in foundThreadblockTilesL0: + testLevel = 0 + numL0 += 1 + foundThreadblockTilesL0[tuple(tbTile)] = True + + # test level 1 + if testLevel < 0: + threadblockTileAlreadyUsed = False + if tuple(threadblockTile) not in foundThreadblockTilesL1: + testLevel = 1 + numL1 += 1 + foundThreadblockTilesL1[tuple(threadblockTile)] = True + + # test level 2 + if testLevel < 0: + testLevel = 2 + numL2 += 1 + + ################################################################ + # write this tile to file + ################################################################ + + print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % ( + threadblockTile[0], threadblockTile[1], unroll, + threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel)) + + out.write("////////////////////////////////////////////////////////////////////////////////\n" + "// Elements / Thread: %3i x %3i\n" + "// Threads / Warp: %3i x %3i\n" + "// Warps / Block: %3i x %3i\n" + "// Threadblock: %3i x %3i x %2i\n" + % ( threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], + threadblockTile[0], threadblockTile[1], unroll + ) + ) + + out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % ( + testLevel, + precisionChar, + transCharA, + transCharB, + threadblockTile[0], + threadblockTile[1], + unroll, + warpShape[0], + warpShape[1], + threadTileM, + threadTileN, + warpThreadsM, + warpThreadsN, + warpsPerThreadblock[0], + warpsPerThreadblock[1] + )) + out.write(" using precision = %s;\n" % precisionType) + out.write(" using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % ( + threadblockTile[0], + threadblockTile[1], + unroll)) + out.write(" using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % ( + warpShape[0], + warpShape[1], + unroll)) + out.write(" static int const kEpilogueElementsPerAccess = 1;\n" + " using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n" + " using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n" + " precision, kEpilogueElementsPerAccess, precision, precision>;\n\n") + + out.write(" using Gemm = cutlass::gemm::device::Gemm<\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::RowMajor,\n" + " precision,\n" + " cutlass::arch::OpClassSimt,\n" + " cutlass::arch::Sm50,\n" + " ThreadblockShape, WarpShape, InstructionShape,\n" + " EpilogueOutputOp,\n" + " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\n" + " 2 // Stages\n" + " >;\n" % ( + "Column" if columnMajorA else "Row", + "Column" if columnMajorB else "Row", + )) + out.write(" EXPECT_TRUE(test::gemm::device::TestAllGemm());\n" + "} )\n\n") + + + out.close() +print("NumKernels:", numL0, numL1, numL2) + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63ffc3281dd2b9e9f74e0024c73da00628331dd4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Host reference and operations for Sm90 EVT unit test +*/ +#pragma once +#include "gemm_testbed_3x_evt.hpp" + +////////////////////////////////////////////////////////////////////////////// +/// Host references used for testing +namespace test::gemm::device { +template +using HEVT = HostTreeVisitor; + +template +using HDAG = HostTopoVisitor; + +template +using HST = HostSplitTreeVisitor; + +/// D = alpha * acc + beta * C + AuxLoad +template +class HostEVTAuxLoad { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = alpha * acc + beta * C + per-column bias +template +class HostPerColBias { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using RowBroadcastNode = HostRowBroadcast; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = beta * C + Graph(relu(alpha * acc + aux) + aux) +/// Testing EVT - DAG structure +template +class HostEVTDAG { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using DAGNode = HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. alpha + cute::tuple<>, // 1. acc + cute::tuple<>, // 2. aux load + cute::tuple, // 3. alpha * acc + aux load + cute::tuple, // relu(alpha * acc + aux load) + cute::tuple // relu(alpha * acc + aux load) + aux load + >, + ScalarAlpha, + AccFetchNode, + AuxLoadNode, + HostCompute, + HostCompute, + HostCompute + >; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +/// Testing DAG - EVT +template +class HostDAGEVT { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTNode = HEVT< + HostAuxStore, + HEVT< + HostCompute, + HostScalarBroadcast<2>, + HostAccumulator<>, + HostAuxLoad + > + >; + using EVTModule = HEVT< + HostAuxStore, + HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. EVT + cute::tuple<>, // 1. per-row bias + cute::tuple, // 2. EVT + per-row bias + cute::tuple // 3. maximum(EVT + per-row bias, EVT) + >, + EVTNode, + HostColBroadcast>, + HostCompute, + HostCompute + > + >; +}; + +/// Xreduce(alpha * acc + beta * C) +template +class HostReduce { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; + using ReduceNode = HEVT; + using EVTModule = HEVT, ReduceNode>; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template class ActivationFn, class ElementD> +class HostScaledLinCombPerRowBiasEltAct { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTModule = HEVT< + HostAuxStore, + HEVT< + HostCompute::template Op>, // activation(Z) * scaled_d + HEVT< + HostCompute, // activation(Z) + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template class ActivationFn, class ElementD, class ElementAux = ElementD> +class HostScaledLinCombPerRowBiasEltActAmaxAux { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + template + using amax = cutlass::maximum_absolute_value_reduction; + using EVTModuleAuxFp8 = HEVT< + HostAuxStore, + HST, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + >, + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HostAccumulator<> // Z + > + >, + HostScalarBroadcast<1> // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + HEVT< + HostAuxStore, + HEVT< + HostCompute, + HEVT< + HostScalarReduce, + HostAccumulator<> + >, + HostScalarBroadcast<1> + > + > + > + >; + + using EVTModuleAuxNotFp8 = HEVT< + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HostAuxStore, + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HEVT< + // Aux = Z + HostAuxStore, + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; + + using EVTModule = cute::conditional_t, EVTModuleAuxFp8, EVTModuleAuxNotFp8>; + +}; +} // namespace test::gemm::device + +////////////////////////////////////////////////////////////////////////////// +namespace cutlass::epilogue { +namespace fusion { + +namespace detail { + +template +struct maximum_with_default_nan_propagation : maximum {}; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoad +template< + class EpilogueDescriptor, + class AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoad = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, + typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, + typename AuxLoadDescriptor::CopyOpS2R // aux load + > + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoadNoSmem +template< + class EpilogueDescriptor, + class ElementAux, + class StrideAux, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoadNoSmem = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad<0, void, ElementAux, StrideAux, void, void> // aux load + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// beta * C + Graph(alpha * acc + gamma + acc) +template< + typename EpilogueDescriptor, + typename AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEVTDAG = + Sm90EVT, // beta * C + (alpha * acc + aux) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, // 0. alpha + cute::seq<>, // 1. acc + cute::seq<>, // 2. aux load + cute::seq<1, 0, 2>, // 3. alpha * acc + aux load + cute::seq<3>, // relu(alpha & acc + aux load) + cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load + >, + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, + typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, + Sm90Compute, + Sm90Compute, + Sm90Compute + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +template< + class EpilogueDescriptor, + class AuxStoreDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDAGEVT = + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<1, 0>, + cute::seq<0, 2> + >, + Sm90EVT< + Sm90AuxStore< + AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, + typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch, + Sm90SrcFetch + > + >, + Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute>, + Sm90Compute, + Sm90Compute + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + per-column bias +template< + class EpilogueDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute> + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-column reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-row reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = scalar reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombScalarReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; +} // namespace fusion + +} // namespace cutlass::epilogue diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..0007666cdd084f35015200e36fd47f75971f6c1c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h @@ -0,0 +1,639 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_utils.h" +#include "testbed_universal.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + typename Gemm::LayoutA::Stride stride_factor_A; + typename Gemm::LayoutB::Stride stride_factor_B; + typename Gemm::LayoutC::Stride stride_factor_C; + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + Testbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(typename Gemm::LayoutA::Stride()), + stride_factor_B(typename Gemm::LayoutB::Stride()), + stride_factor_C(typename Gemm::LayoutC::Stride()), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + Testbed( + typename Gemm::LayoutA::Stride stride_factor_A_, + typename Gemm::LayoutB::Stride stride_factor_B_, + typename Gemm::LayoutC::Stride stride_factor_C_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_C(stride_factor_C_), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); + tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); + tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0) + << "tensor_D (size " << tensor_D.size() << ") has nonpositive norm"; + } + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0) + << "reference_D (size " << reference_D.size() << ") has nonpositive norm"; + } + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << "reference_D does not equal tensor_D"; + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "split_k_slices: " << split_k_slices << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op.initialize returned with error " << to_string(status) + << ", indicating that this test is not supported. Last CUDA error: " + << cudaGetErrorString(cudaGetLastError()); + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + try { + status = gemm_op(); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "gemm_op() threw a std::exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "gemm_op() threw an exception of unknown type"; + throw; + } + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op failed with error " << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + EXPECT_TRUE(passed) << "Error: split_k_slices = " << split_k_slices + << ", alpha: " << alpha; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmBasic( + const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = { + kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemm( + const typename Gemm::LayoutA::Stride& stride_factor_A, + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) +{ + // Test basic GEMM with non-default stride factors + return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); +} + +template +bool TestAllGemm() +{ +#ifdef NDEBUG + // Non-debug builds also test basic GEMM with default stride factors + if (!TestAllGemmBasic()) { + return false; + } +#endif // NDEBUG + + // Test universal GEMM +#if 0 + // Define the universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle + >; +#else + // Define the streamk universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle + >; +#endif + + // Define the universal adaptor + using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Test universal GEMM + return TestAllGemmUniversal(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmPerf(int iterations = 1) { + bool passed = true; + + int problem_size_m[] = { 2048 }; + + int problem_size_n[] = { 4352 }; + + int problem_size_k[] = { 4096 }; + + int split_k_slices[] = { 1 }; + double problem_alpha[] = { 1 }; + double problem_beta[] = { 0.0 }; + + Testbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + for (int i = 0; i < iterations; i++){ + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..add984ca3b9a0c05325b93cf52cbadd710527ba6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedComplex : public Testbed { + + using Base = Testbed; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + + // + // Methods + // + + TestbedComplex( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex( + problem_size, + alpha, + this->tensor_A.host_ref(), + Gemm::kTransformA, + this->tensor_B.host_ref(), + Gemm::kTransformB, + beta, + this->tensor_C.host_ref(), + this->reference_D.host_ref(), + ElementAccumulator(0) + ); + + return this->compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Initialize workspace + // + + this->initialize(problem_size); + + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmComplex() { + bool passed = true; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = + cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + int problem_size_m[] = { + kAlignment, 512 - 3*kAlignment + }; + + int problem_size_n[] = { + kAlignment, 512 - 2*kAlignment + }; + + int problem_size_k[] = { + kAlignment, 128 - kAlignment + }; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + TestbedComplex testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..eca0b0ae0decf3293f6f73cb6ebbc5b5735a8e49 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -0,0 +1,670 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithBroadcastReferenceOp { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + GemmWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { + + ElementCompute t_full = binary_op(gemm, bias); + + if (OutputOp::kStoreT) { + T = ElementT(t_full); + } + + if (OutputOp::kStoreZ) { + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = GEMM(AB, C) +// +// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) +// +// Z[i, j] = Elementwise(T[i, j]) +// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +struct TestbedGemmWithBroadcast { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename OutputOp::ElementCompute; + using ElementVector = typename OutputOp::ElementVector; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; // Input A + cutlass::HostTensor tensor_B; // Input B + cutlass::HostTensor tensor_C; // Input C + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + + cutlass::HostTensor tensor_Z; + cutlass::HostTensor tensor_T; + + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor tensor_Y_ref; + cutlass::HostTensor tensor_Z_ref; + cutlass::HostTensor tensor_T_ref; + + + // + // Methods + // + + TestbedGemmWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_Z.resize(problem_size.mn()); + tensor_T.resize(problem_size.mn()); + tensor_Broadcast.resize({ + problem_size.m(), + 1 + }); + + tensor_C_ref.resize(problem_size.mn()); + tensor_Y_ref.resize(problem_size.mn()); + tensor_Z_ref.resize(problem_size.mn()); + tensor_T_ref.resize(problem_size.mn()); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + + tensor_Z.sync_device(); + tensor_T.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Z.sync_host(); + tensor_T.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (OutputOp::kStoreZ) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); + } + + if (OutputOp::kStoreT) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); + } + + bool passed = true; + float norm_diff = 0; + + if (OutputOp::kStoreZ) { + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); + passed = (norm_diff <= 0.1f); + EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; + } + + if (OutputOp::kStoreT) { + + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); + passed = (passed && (norm_diff <= 0.1f)); + + EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; + } + + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); + + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nZ =\n" << tensor_Z.host_view() + << "\nT =\n" << tensor_T.host_view() + << "\n\n" + << "\nY_ref =\n" << tensor_Y_ref.host_view() + << "\nZ_ref =\n" << tensor_Z_ref.host_view() + << "\nT_ref =\n" << tensor_T_ref.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + tensor_Y_ref.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + + ElementZ z; + ElementT t; + + reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); + + if (OutputOp::kStoreZ) { + tensor_Z_ref.at({m, n}) = z; + } + + if (OutputOp::kStoreT) { + tensor_T_ref.at({m, n}) = t; + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_Z.device_data(), + tensor_Broadcast.device_data(), + tensor_T.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = true; + + passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. + + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), + profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), + profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestGemmWithBroadcast( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestAllGemmWithBroadcast() { + + int M_problems[] = {8, 136, 264, 520}; + int N_problems[] = {8, 136, 264, 520}; + int K_problems[] = {8, 136, 264, 520}; + double alpha_problems[] = {1.25, 2.25}; + double beta_problems[] = {0, 1, 2.0}; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + 1, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; + + if (!passed) { + + return passed; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..af3629ccfb87e09e80b85af508379780d6428dc5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithReductionReference { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; + using ElementC = typename Gemm::ElementC; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + // + // Data members + // + + BinaryOp binary_op; + + // + // Methods + // + + GemmWithReductionReference() { } + + ElementCompute operator()( + ElementAccumulator d_y, + ElementT t) { + + return binary_op(ElementCompute(d_y), ElementCompute(t)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp +> +struct TestbedGemmWithReduction { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor reference_d_Y; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Reduction; + + // + // Methods + // + + TestbedGemmWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + for (int m = 0; m < view.extent().row(); ++m) { + for (int n = 0; n < view.extent().column(); ++n) { + //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); + view.at({m, n}) = (n == 0 ? Element(m) : Element()); + + } + } + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + + tensor_Reduction.resize({ + problem_size.m(), + (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN + }); + + tensor_Tensor.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + reference_d_Y.resize(problem_size.mn(), false); + tensor_C_ref.resize(problem_size.mn(), false); + reference_Reduction.resize({problem_size.m(), 1}, false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Reduction.sync_device(); + tensor_Tensor.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Reduction.sync_host(); + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); + + bool passed = true; + for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { + + ElementAccumulator reduced_value = ElementAccumulator(); + for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { + reduced_value += tensor_Reduction.at({m, j}); + } + + if (reduced_value != reference_Reduction.at({m, 0})) { + std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; + passed = false; + break; + } + } + EXPECT_TRUE(passed) << "Reduction is incorect."; + + if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { + EXPECT_TRUE(false) << " mismatched reference"; + passed = false; + } + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors_sm70.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nT = \n" << tensor_Tensor.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view() + << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" + << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + reference_d_Y.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute backwards + for (int m = 0; m < problem_size.m(); ++m) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.n(); ++n) { + ElementAccumulator d_full = reference_op(reference_d_Y.at({m, n}), tensor_Tensor.at({m, n})); + reduced_value += d_full; + reference_D.at({m, n}) = ElementC(d_full); + } + reference_Reduction.at({m, 0}) = reduced_value; + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. + + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), + profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), + profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmWithReduction( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count = 1, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithReduction testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..c7317eb855477e63fe19858ca51cd5722f236eb5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h @@ -0,0 +1,501 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24); + + if (!i) { + problem = cutlass::gemm::GemmCoord(48, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + // std::cout << "Problem[" << i << "]: " << problem << std::endl; + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? + } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a GEMM + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::host::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0) + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; + } + return true; + } + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = gemm.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the GEMM object + status = gemm.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..f8f08f23c4477745648f1cf8f9e439ae6b5061e2 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h @@ -0,0 +1,502 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + auto N = 8 * (rand() % 64) + 24; + auto K = 8 * (rand() % 64) + 24; + cutlass::gemm::GemmCoord problem(N, N, K); + + if (!i) { + problem = cutlass::gemm::GemmCoord(16, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? + } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a Rank2K + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Rank2K::kTransformA, + view_B, + Rank2K::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; + } + return true; + } + + // Configure the Rank2K arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k; + + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = rank2k.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the Rank2K object + status = rank2k.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..e9315e12e8711f50256e4cfe05666201acd614d3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/device_kernel.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static cutlass::FillMode const kFillModeC = FillModeC; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} + + CUTLASS_DEVICE + cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { + int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; + int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; + int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); + + if (FillModeC == cutlass::FillMode::kUpper) { + cutlass::swap(macro_row, macro_col); + } + + int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); + int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); + + return cutlass::gemm::GemmCoord(row, col, 0); + } +}; + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); + + problem_visitor.advance(gridDim.x); + + // + // Early exit conditions + // 1) Out of range + // 2) Upper-triangular block in lower-triangular problem + // 3) Lower-triangular block in upper-triangular problem + // + + if (grid_shape.m() <= tile_offset.m() || + grid_shape.n() <= tile_offset.n()) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && + (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && + tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + if (skip_tile_check) { + return true; + } + + return verify(); + } +}; + +template +struct TestbedGroupedRank2KScheduler { + + using BaselinePV = BaselineProblemVisitor, + ThreadblockShape, + PrefetchTileCount, + ThreadCount, + FillModeC>; + + // + // Data members + // + + // Whether to skip checking that the tiles are visited as expected. This is useful + // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped + // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to + // exit early, but which are difficult to detect in tests without reimplementing + // this functionality. + bool skip_tile_check; + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): + skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + int n = scale_factor * (rand() % 64) + 24; + + cutlass::gemm::GemmCoord problem( + n, + n, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + FillModeC>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run(skip_tile_check)); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run(skip_tile_check)); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..bda2704b517ea95052e2c2060b50712b686344f6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h @@ -0,0 +1,407 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped GEMM problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + problem_visitor.advance(gridDim.x); + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + return verify(); + } +}; + +template +struct TestbedGroupedGemmScheduler { + + using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; + using BaselinePV = BaselineProblemVisitor; + + // + // Data members + // + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): + seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + Transpose>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run()); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run()); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..2a5956000db8e8c05ea22538e58149998b03e3fc --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct InterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + InterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..32452c30e05f64763a268195ae78138f26c09735 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +class TestbedPlanarComplex { +public: + + using ElementA = typename Gemm::ElementA; + using LayoutA = typename Gemm::LayoutA; + using ElementB = typename Gemm::ElementB; + using LayoutB = typename Gemm::LayoutB; + using ElementC = typename Gemm::ElementC; + using LayoutC = typename Gemm::LayoutC; + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + // + // Data members + // + + cutlass::gemm::GemmCoord problem_size; + cutlass::HostTensorPlanarComplex tensor_A; + cutlass::HostTensorPlanarComplex tensor_B; + cutlass::HostTensorPlanarComplex tensor_C; + cutlass::HostTensorPlanarComplex tensor_D; + cutlass::HostTensorPlanarComplex tensor_D_ref; + + // + // Methods + // + + TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) { + + tensor_A.reset({problem_size.m(), problem_size.k()}); + tensor_B.reset({problem_size.k(), problem_size.n()}); + tensor_C.reset({problem_size.m(), problem_size.n()}); + tensor_D.reset({problem_size.m(), problem_size.n()}); + tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false); + } + + void initialize() { + + uint64_t seed = 1073; + + int scope_max = 8; + int scope_min = -8; + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0); + + cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex()); + cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + bool run( + cutlass::complex alpha = {1, 0}, + cutlass::complex beta = {0, 0}) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + initialize(); + + int batch_count = 1; + + ElementA *ptr_A = tensor_A.device_data(); + ElementB *ptr_B = tensor_B.device_data(); + ElementC *ptr_C = tensor_C.device_data(); + ElementC *ptr_D = tensor_D.device_data(); + + typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); + typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); + typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); + typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); + + int64_t imag_stride_A = tensor_A.imaginary_stride(); + int64_t imag_stride_B = tensor_B.imaginary_stride(); + int64_t imag_stride_C = tensor_C.imaginary_stride(); + int64_t imag_stride_D = tensor_D.imaginary_stride(); + + // + // Launch device kernel + // + + Gemm gemm_op; + + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + {alpha, beta}, + ptr_A, + ptr_A + imag_stride_A, + ptr_B, + ptr_B + imag_stride_B, + ptr_C, + ptr_C + imag_stride_C, + ptr_D, + ptr_D + imag_stride_D, + lda, + lda, + ldb, + ldb, + ldc, + ldc, + ldd, + ldd + }; + + cutlass::Status status = gemm_op(args); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t error = cudaDeviceSynchronize(); + + tensor_D.sync_host(); + + // + // Compute reference + // + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + tensor_D_ref.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_D.host_view(), + tensor_D_ref.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("gemm_planar_complex.txt"); + + output + << "A:\n" << tensor_A.host_view() << "\n" + << "B:\n" << tensor_B.host_view() << "\n" + << "C:\n" << tensor_C.host_view() << "\n" + << "Reference:\n" + << tensor_D_ref.host_view() << "\n" + << "Computed:\n" + << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +template +bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { + + TestbedPlanarComplex testbed(problem_size); + + return testbed.run(); +} + +template +bool TestAllGemmPlanarComplex() { + + int M[] = { + 16, 64, 72, 144, 264, 520, + }; + + int N[] = { + 16, 64, 72, 144, 248, 264, 520 + }; + + int K[] = { + 8, 64, 72, 96, 264, 520 + }; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + cutlass::complex alpha_values[] = { + {ElementCompute(1.25), ElementCompute(-0.5)} + }; + + cutlass::complex beta_values[] = { + {ElementCompute(-2.25), ElementCompute(1.5)} + }; + + for (int m : M) { + for (int n : N) { + for (int k : K) { + + test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = testbed.run(alpha, beta); + if (!passed) { + return false; + } + } + } + } + } + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..4d9f6743a45e5dc3a7b4ddd3e2a7b2abceffbb18 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Rank2K workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Rank2K + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename Rank2K::ElementA, typename Rank2K::LayoutA, + typename Rank2K::ElementB, typename Rank2K::LayoutB, + typename Rank2K::ElementC, typename Rank2K::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Rank2K::kTransformA, + tensor_B.host_ref(), + Rank2K::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Rank2K::Rank2Kkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Rank2K operator + // + + typename Rank2K::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Rank2K rank2k_op; + + size_t workspace_size = Rank2K::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Rank2K + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_Rank2k_device_" + << "fill_mode_c_" + << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Rank2K::ThreadblockShape::kM << "x" + << Rank2K::ThreadblockShape::kN << "x" + << Rank2K::ThreadblockShape::kK << "_" + << Rank2K::WarpShape::kM << "x" + << Rank2K::WarpShape::kN << "x" + << Rank2K::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRank2KUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.25 + }; + + double problem_beta[] = { + 0.0, 2.15 + }; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +template +bool TestAllRank2KHermitianUniversal() { + bool passed = true; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + /* Complex alpha for HER2K */ + ElementAccumulator problem_alpha[] = { + {1.0}, + {1.25, 3.25}, + {-0.25, -2.25} + }; + + ElementAccumulator problem_beta[] = { + 0.0, -2.25 + }; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + alpha, + beta + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..cb46528a049ae1254d0492b6235821210e47b957 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h @@ -0,0 +1,511 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename RankK::ElementA; + using ElementC = typename RankK::ElementC; + using ElementAccumulator = typename RankK::ElementAccumulator; + using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the RankK workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); + tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a RankK + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename RankK::ElementA, typename RankK::LayoutA, + typename RankK::ElementC, typename RankK::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + RankK::kTransformA, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + RankK::kFillModeC, + RankK::kBlasMode + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename RankK::RankKkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the RankK operator + // + + typename RankK::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + RankK rank2k_op; + + size_t workspace_size = RankK::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the RankK + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_RankK_device_" + << "fill_mode_c_" + << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (RankK::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << RankK::ThreadblockShape::kM << "x" + << RankK::ThreadblockShape::kN << "x" + << RankK::ThreadblockShape::kK << "_" + << RankK::WarpShape::kM << "x" + << RankK::WarpShape::kN << "x" + << RankK::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRankKUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + int const kAlignmentN = 128 / kMinimumOperandElementSize; + int const kAlignmentK = 128 / kMinimumOperandElementSize; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h new file mode 100644 index 0000000000000000000000000000000000000000..0a01a6a32ee2db84f2e890059423cd6b8477f766 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/core_io.h" + +#include "testbed.h" + + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// List of Gemm internal paramters this testbed supports user verification +// +enum class ParameterID { + + // Threadblock-level parameters + kSmemASize, + kSmemBSize, + + // Warp-level parameters + kWarpFragmentASize, + kWarpFragmentBSize, + kWarpFragmentCSize, + kInvalid +}; + +struct Reference { + ParameterID parameter_id; + + union { + int value; + + struct { + int m, n, k; + } gemm_shape; + + struct { + int row, column; + } matrix_shape; + }; + + std::string error_msg; + + Reference( + ParameterID parameter_id_, + int value_=-1, + std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} +}; + + +template +struct TestbedSanity { + + // + // Type definitions (All Gemm types top down) + // + + // Unpacking Gemm types in the following order + // Kernel-level > Threadblock-level > Warp-level > Instruction-level + + // kernel-level cutlass Gemm + using GemmKernel = typename Gemm::GemmKernel; + + // + // Threadblock-level gemm types + // + using MmaThreadBlock = typename GemmKernel::Mma; + + // Threadblock-level gemm shape covering one stage + using ThreadblockShape = typename MmaThreadBlock::Shape; + + // Shared memory size covering all stages + using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; + using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; + using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; + using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; + + + /// Number of stages + static int const kStages = MmaThreadBlock::Base::kStages; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; + + + // + // Warp-level gemm types + // + + // Warp-level gemm operator + using MmaWarp = typename MmaThreadBlock::Operator; + + // Warp-level gemm shape covering all kgroups + using WarpShape = typename MmaWarp::Shape; + + // Warp-level framents holding operands A & B operand and destination C + using WarpFragmentA = typename MmaWarp::FragmentA; + using WarpFragmentB = typename MmaWarp::FragmentB; + using WarpFragmentC = typename MmaWarp::FragmentC; + + // + // Instruction-level gemm types + // + + // Instruction-level gemm operator + using MmaInstruction = typename MmaWarp::Policy::Operator; + + // Instruction shape + using InstructionShape = typename MmaInstruction::Shape; + + // Instruction-level framents holding operands A & B operand and destination C + using InstructionFragmentA = typename MmaInstruction::FragmentA; + using InstructionFragmentB = typename MmaInstruction::FragmentB; + using InstructionFragmentC = typename MmaInstruction::FragmentC; + + // + // Testbed types + // + + // Vector of values holding user provided reference + using ReferenceVector = std::vector; + + // + // Data members + // + ReferenceVector references; + + // + // Methods + // + + TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } + + // verify all parameter in ReferenceVector + bool verify() { + for(auto ref : references) + verify_parameter(ref); + return true; + } + + // verify parameter of type Reference + void verify_parameter(Reference const& ref) { + switch(ref.parameter_id) { + case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Overload output operators for TesbedSanity +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { + + + out << "Gemm internal parameters" << std::endl + << " Threadblock-level parameters:" << std::endl + << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl + << " kStages = " << TestbedSanity::kStages << std::endl + << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl + <<" Shared memory sizes:" << std::endl + <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl + <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl + <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl + <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl + <<" Warp-level parameters" << std::endl + <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl + <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl + <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl + <<" Instruction-level parameters" << std::endl + <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl + <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl + <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; + + return out; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..a95bf996bac337b44da616dc9fbf9c9bdb2a625c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + + Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SparseTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + static int const kSparse = Gemm::GemmKernel::kSparse; + static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; + static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; + static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; + + using ElementE = typename Gemm::GemmKernel::ElementE; + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_E; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + SparseTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), + init_B(init_B_), + init_C(init_C_), + init_E(init_E_), + seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); + tensor_A_uncompressed.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + tensor_E.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + tensor_E_reordered.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, kMetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + EXPECT_TRUE(false); + } + + cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_E_reordered.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nE =\n" << tensor_E.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), problem_size.m(), problem_size.k()); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_E_reordered.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_E_reordered.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + // This failure is likely due to insufficient device capabilities. Waive the test. + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << ", beta: " << beta << ", m: " << problem_size.m() << ", n: " << problem_size.n() << ", k:" < +bool TestAllSparseGemm() { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), + kMinimumOperandElementSize); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = {Gemm::ThreadblockShape::kK * 8}; + + int split_k_slices[] = { + 1, 2 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + SparseTestbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h new file mode 100644 index 0000000000000000000000000000000000000000..8fa4a85505316d08f1d050702b78448f8fae8565 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSplitK : public Testbed { + + using Base = Testbed; + + using ElementCompute = typename Base::ElementCompute; + + // + // Methods + // + + TestbedSplitK( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + return this->verify(problem_size, alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmSplitK() { + bool passed = true; + + cutlass::gemm::GemmCoord problem_sizes[] = { + {8, 8, 2048}, + {8, 8, 2056}, + {264, 72, 520}, + {264, 520, 120}, + {264, 520, 264} + }; + + int split_k_slices[] = { + 1, 2, 4, 5, 7 + }; + + double problem_alpha[] = { + 0.5 + }; + + double problem_beta[] = { + 2.0 + }; + + using Testbed = TestbedSplitK; + using ElementCompute = typename Testbed::ElementCompute; + + Testbed testbed; + + for (auto problem_size : problem_sizes) { + for (int split_k_count : split_k_slices) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = testbed.run( + problem_size, + split_k_count, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; + return false; + } + } + } + } + } + + EXPECT_TRUE(passed); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b7a57f7eb0ca73c23460e5a9ce1301061c2cc286 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Symm update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/symm.h" +#include "cutlass/util/reference/host/symm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSymmUniversal { + + using ElementA = typename Symm::ElementA; + using ElementB = typename Symm::ElementB; + using ElementC = typename Symm::ElementC; + using ElementAccumulator = typename Symm::ElementAccumulator; + using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedSymmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Symm workspace + // + + if (Symm::kSideModeA == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Symm::kSideModeA == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Symm + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::SymmComplex< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator, + Symm::kBlasMode>, + cutlass::reference::host::Symm< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_symm; + + reference_symm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Symm::SymmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Symm operator + // + + int batch_stride_A; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Symm::kSideModeA == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Symm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Symm symm_op; + + size_t workspace_size = Symm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = symm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Symm + // + + status = symm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_" + << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) + << "device_" + << "fill_mode_a_" + << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : + (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) + << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : + (Symm::kFillModeA == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Symm::ThreadblockShape::kM << "x" + << Symm::ThreadblockShape::kN << "x" + << Symm::ThreadblockShape::kK << "_" + << Symm::WarpShape::kM << "x" + << Symm::WarpShape::kN << "x" + << Symm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "alpha: " << ElementCompute(alpha) << "\n" + << "beta: " << ElementCompute(beta) << "\n" + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestsymmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedSymmUniversal testbed; + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllSymmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Symm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.0 + }; + + double problem_beta[] = { + 0, 2.0 + }; + + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + int k = 0; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + k = m; + else if (Symm::kSideModeA == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + #if 0 + // skip very small K problems + if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { + continue; + } + #endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedSymmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b30acfed6bba547986efd3afa8eb829be2a255e4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide TRMM interface + + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/trmm.h" +#include "cutlass/util/reference/host/trmm_complex.h" +#include "cutlass/core_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedTrmmUniversal { + + using ElementA = typename Trmm::ElementA; + using ElementB = typename Trmm::ElementB; + using ElementC = typename Trmm::ElementC; + using ElementAccumulator = typename Trmm::ElementAccumulator; + using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_D; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedTrmmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) + template + bool initialize_pad_diagonal_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int alignment) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillPadDiagonalRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the TRMM workspace + // + + if (Trmm::kSideMode == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Trmm::kSideMode == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); + //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2019, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a TRMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::TrmmComplex< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kTransformA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + Trmm::kTransformB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator>, + cutlass::reference::host::Trmm< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_trmm; + + reference_trmm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Trmm::TrmmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the TRMM operator + // + + int batch_stride_A; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Trmm::kSideMode == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Trmm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Trmm trmm_op; + + size_t workspace_size = Trmm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the TRMM + // + + status = trmm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha); + + if (!passed) { + std::stringstream fname; + + fname << "error_Trmm_device_" + << "fill_mode_" + << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : + (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "side_mode_" + << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : + (Trmm::kSideMode == cutlass::SideMode::kRight ? "right_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Trmm::ThreadblockShape::kM << "x" + << Trmm::ThreadblockShape::kN << "x" + << Trmm::ThreadblockShape::kK << "_" + << Trmm::WarpShape::kM << "x" + << Trmm::WarpShape::kN << "x" + << Trmm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestTrmmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0) { + + bool passed = true; + + TestbedTrmmUniversal testbed; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + return passed; +} + +template +bool TestAllTrmmUniversal() { + bool passed = true; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Trmm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 2.0 + }; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + for (auto alpha : problem_alpha) { + + int k = 0; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + k = m; + else if (Trmm::kSideMode == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + +#if 0 + // skip very small K problems + if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { + continue; + } +#endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedTrmmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..00368a5e8eebc128719f64069583010c83dc0c1f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedUniversal { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + bool is_unsigned_int = std::numeric_limits::is_integer && !std::numeric_limits::is_signed; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = is_unsigned_int ? 2 : 1; + scope_min = is_unsigned_int ? 0 : -1; + } else if (bits_output == 16) { + constexpr auto u8_bf16 = + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value); + scope_max = is_unsigned_int ? 10 : (u8_bf16 ? 3 : 5); + scope_min = is_unsigned_int ? 0 : (u8_bf16 ? -3 : -5); + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); + tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); + tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + /* + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "mode: " << (int) mode << "\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "batch_count: " << batch_count << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedUniversal testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllGemmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentM, 512 - 3*kAlignmentM + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1, 2, 3, 5, 7 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { + continue; + } + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + } + + /* + // large problem with high coverage + for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { + TestbedUniversal testbed; + + cutlass::gemm::GemmCoord problem_size(72, 56, 8192); + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + cutlass::from_real(1.0), + cutlass::from_real(2.0) + ); + + if (!passed) { + break; + } + } + */ + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..89ac33a1028061515d08d50fdb6cce7833ae88ce --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +inline char const *to_string(cutlass::Status status) { + + switch (status) { + case cutlass::Status::kSuccess: return "kSuccess"; + case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; + case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "kErrorInternal"; + case cutlass::Status::kInvalid: return "kInvalid"; + default: break; + } + return "invalid"; +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5588f57c40c4e8f8d06adfa9f1e673350fb5e5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h @@ -0,0 +1,609 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level GEMMs with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" +#include "testbed_sparse.h" +#include "testbed_utils.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor +> +struct TestbedWithAmax { + + static_assert(std::is_same_v> || std::is_same_v>); + static constexpr bool IsSparseTestbed = std::is_same_v>; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Gemm::EpilogueOutputOp::ElementAbsmax; + + static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + GemmTestbed underlying_testbed; + + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedWithAmax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + underlying_testbed(init_A_, init_B_, init_C_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + underlying_testbed.initialize(problem_size); + + tensor_Vector.resize({1, problem_size.n()}); + reference_D.resize(problem_size.mn(), false); + tmp_D.resize(problem_size.mn(), false); + + EXPECT_TRUE( + underlying_testbed.initialize_tensor(tensor_Vector.host_view(), underlying_testbed.init_C, underlying_testbed.seed + 2020) + ); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), underlying_testbed.tensor_C.host_view()); + + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), underlying_testbed.seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), underlying_testbed.seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), underlying_testbed.seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), underlying_testbed.seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(problem_size.mn()); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), underlying_testbed.seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(problem_size.mn(), false); + reference_abs_max_Aux.resize({1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + underlying_testbed.tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), underlying_testbed.tensor_D.host_view()); + if (!passed) { + std::cout << "Comparison of D failed" << std::endl; + } + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux failed" << std::endl; + } + if (!cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux absmax failed" << std::endl; + } + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view())) { + passed = false; + std::cout << "Comparison of D absmax failed" << std::endl; + } + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file("testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << underlying_testbed.tensor_A.host_view() + << "\nB =\n" << underlying_testbed.tensor_B.host_view() + << "\nC =\n" << underlying_testbed.tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << "\n\nReference D =\n" << reference_D.host_view() + << "\nComputed D =\n" << underlying_testbed.tensor_D.host_view(); + if (kScaleAux) { + file + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\nComputed Aux =\n" << tensor_Aux.host_view() + << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() + << "\nComputed Absmax Aux = " << abs_max_Aux.host_view(); + } + if (kScaleOutput) { + file + << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() + << "\nComputed Absmax D = " << abs_max_D.host_view(); + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<2> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + auto ref_tA = [&](){ + if constexpr (IsSparseTestbed) { + cutlass::uncompress( + underlying_testbed.tensor_A_uncompressed.host_ref(), + underlying_testbed.tensor_A.host_ref(), + underlying_testbed.tensor_E.host_ref(), + problem_size.m(), + problem_size.k() + ); + return underlying_testbed.tensor_A_uncompressed.host_ref(); + } + else { + return underlying_testbed.tensor_A.host_ref(); + } + }(); + + // Run reference kernel with ElementOutput of type ElementAccumulator + // so that we can compute the absmax epilogue on data that is of type + // ElementAccumulator (which is what the GEMM we are testing will do). + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + problem_size, + scaled_alpha, + ref_tA, + Gemm::kTransformA, + underlying_testbed.tensor_B.host_ref(), + Gemm::kTransformB, + scaled_beta, + underlying_testbed.tensor_C.host_ref(), + tmp_D.host_ref(), + ElementAccumulator(0) + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + return underlying_testbed.sufficient(); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Gemm::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + auto arguments = [&]() { + if constexpr (IsSparseTestbed) { + return typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + underlying_testbed.tensor_E_reordered.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + underlying_testbed.tensor_E_reordered.layout().stride(0), + tensor_Aux.layout().stride(0), + 0 // stride vector + }; + } + else { + return typename Gemm::Arguments{ + mode, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + 0, // stride vector + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + (int64_t)0 // Leading dimension of vector. This must be 0 + }; + } + }(); + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllGemmWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int constexpr kAlignmentM = [&]() { + if constexpr (std::is_same_v>) { + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + return std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), + kMinimumOperandElementSize); + } + else { + return 128 / kMinimumOperandElementSize; + } + }(); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int M_problems[] = {kAlignmentM, 128 + 32}; + int N_problems[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + int K_problems[] = {Gemm::ThreadblockShape::kK * 2}; + double alpha_problems[] = {1.}; + double beta_problems[] = {0.}; + int split_k_slices[] = { + 1, 2 + }; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (int split_k : split_k_slices) { + if (cutlass::sizeof_bits_v <= 8 && split_k > 1) { + // Don't test split-K with FP8 output. The kernel being tested will writie partial accumulations + // for different splits to global memory in FP8, while the reference kernel will not. This leads + // to mismatches that are difficult to capture without a permissive relative equality check threshold. + continue; + } + + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + TestbedWithAmax testbed(scaleA, scaleB, scaleC); + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta << ", split_k:" << split_k; + + if (!passed) { + + return passed; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..8e939f9710403a5f5c3fd8c61e34c4e8021ff423 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h @@ -0,0 +1,358 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/core_io.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "cutlass/gemm/kernel/default_gemv.h" +#include "cutlass/gemm/kernel/gemv_batched_strided.h" + +namespace test { +namespace gemm { +namespace kernel { + +template +void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + bool perf_test = false, + int perf_test_iter = 1) +{ + using ThreadBlockShape = ThreadBlockShape_; + using ThreadShape = ThreadShape_; + using ElementA = ElementAB_; + using LayoutA = LayoutA_; + using ElementB = ElementAB_; + using LayoutB = LayoutB_; + using ElementAccumulator = ElementCD_; + using ElementCD = ElementCD_; + using LayoutCD = LayoutCD_; + + using GemvKernel = cutlass::gemm::kernel::DefaultGemv; + + using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; + using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; + + if (DEBUG) + { + problem_size = cutlass::gemm::BatchedGemmCoord( + problem_size.m(), problem_size.n(), problem_size.k(), 1); + } + + // Create host tensors that will be the backing store for the batches + // Note that no device memory is initially allocated + cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); + cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); + + // Reserve memory for the batch of tensors + matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); + matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); + matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); + matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); + + // Fill eatch tensor batch + const int seed = 9876; + for (int b = 0; b < problem_size.batch(); b++) + { + if(DEBUG) + { + cutlass::reference::host::BlockFillSequential( + matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); + cutlass::reference::host::BlockFillSequential( + matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(b*matrix_A.capacity()), + seed + 1660, + 8, + -8, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(b*matrix_B.capacity()), + seed + 1880, + 8, + -8, + 0 + ); + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); + cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + ThreadBlockSwizzle swizzle; + + cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, + ThreadBlockShape::kN, + problem_size.k(), // no split-k + DEBUG ? 1 : THREAD_B }; + + cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); + + #if 0 + printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); + printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); + #endif + + // No split-k + EXPECT_EQ(tiled_size.k(), problem_size.k()); + + dim3 grid = swizzle.get_grid_shape(tiled_shape); + dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); + + // Some sanity checks + EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); + EXPECT_TRUE( block.x <= 1024 ); + EXPECT_TRUE( block.y <= 1024 ); + EXPECT_TRUE( block.z <= 64 ); + + #if 0 + printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); + printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); + #endif + + cudaError_t result; + cudaEvent_t start_event, end_event; + + for (int iter = 0; iter < (perf_test ? (perf_test_iter+1) : 1); ++iter) + { + if (perf_test && iter == 1) + { + result = cudaEventCreate(&start_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventCreate(&end_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventRecord(start_event); + EXPECT_EQ(result, cudaSuccess); + } + + if (beta == ElementCD(0)) + { + if (alpha == ElementCD(1)) + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + beta, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + + if (iter == 0) + { + result = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + } + } + + if (perf_test) + { + result = cudaEventRecord(end_event); + EXPECT_EQ(result, cudaSuccess); + } + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + + if (perf_test) + { + float ms; + result = cudaEventElapsedTime(&ms, start_event, end_event); + EXPECT_EQ(result, cudaSuccess); + + double flops = (double(problem_size.m()) * + double(problem_size.n()) * + double(problem_size.k()) * + double(problem_size.batch()) * 2); // 2 for MAC + + double read_bytes = double(problem_size.batch()) * (sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + + sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); + + double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); + + double avg_runtime = double(ms) / perf_test_iter; + double gflops_per_sec = flops / 1.0e6 / avg_runtime; + double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; + double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; + + std::cout << "\n\nProblem size: " + << problem_size.m() + << " x " << problem_size.n() + << " x " << problem_size.k() + << " x " << problem_size.batch() + << std::endl; + + std::cout << " GFLOPs: " << gflops_per_sec << std::endl; + std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; + std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; + } + else + { + matrix_C_computed.sync_host(); + + // Compute the batched gemms + for (int b = 0; b < problem_size.batch(); b++) + { + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size.mnk(), alpha, + matrix_A.host_ref(b * matrix_A.capacity()), + matrix_B.host_ref(b * matrix_B.capacity()), beta, + matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(b * matrix_C_computed.capacity()), + matrix_C_reference.host_view(b * matrix_C_reference.capacity())); + + EXPECT_TRUE(passed) + //<< "A:\n" << matrix_A.host_view() << "\n" + //<< "B:\n" << matrix_B.host_view() << "\n" + << "Batch: " << b << "\n" + << "Reference:\n" + << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) + << "\n" + << "Computed:\n" + << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) + << "\n"; + } + } +} + +template +void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + int iter = 50) +{ + batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); +} + +} // namespace threadblock +} // namespace kernel +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h new file mode 100644 index 0000000000000000000000000000000000000000..6e3d6ab079d44345f2f55f4126ba3efc1eba47cb --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/layout/vector.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + using Btype = typename Mma::ElementB; + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( + 0, // seed + 10, // max + 0, // min + 0); // bits after decimal + + cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( + tensor_A.host_view(), + tfill_rand_func); + + for (auto i=0; i< Shape::kM; i++) + for (auto j=0; j< Shape::kK; j++) + tfill_rand(cutlass::make_Coord(i,j)); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + + // Host side call + kernel( + tensor_D_computed.host_data(), + tensor_A.host_data(), + tensor_B.host_data(), + tensor_C.host_data()); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8d34d7992b57cefa0eaf7300a5e1fb49f41a93e2 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..1f3bc8cf114d7eb2ac00bd19ae92c984558b7228 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -0,0 +1,435 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, + typename Mma::IteratorE::Params params_E, + typename Mma::IteratorE::TensorRef ref_E) { + // Shared storage needed by threadblock-scoped matrix multiply- + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k() / Mma::kSparse}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorE iterator_E( + params_E, ref_E.data(), + {problem_size.m(), + problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, + tb_thread_id, tb_offset_E); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct SparseTestbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ElementE = typename MmaCore::ElementE; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using ThreadMapE = typename MmaCore::IteratorThreadMapE; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + using AccessTypeE = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + static cutlass::arch::CacheOperation::Kind const CacheOpE = + MmaCore::kCacheOpE; + + static int const Sparse = MmaCore::kSparse; + static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; + static int const MaxID2 = MmaCore::kMaxID2; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename MmaCore::GmemLayoutE; + + static int const ElementsPerElementE = MmaCore::kElementsPerElementE; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define iterators over tiles from the E operand + using IteratorE = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, + typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_A_uncompressed; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_E; + cutlass::HostTensor matrix_E_reordered; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); + matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + matrix_E_reordered.reset( + cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + // Waive the test + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + matrix_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(matrix_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / Sparse / ElementsPerElementE}); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + matrix_E_reordered.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + typename IteratorE::Params params_E(matrix_E_reordered.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma_sparse + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), params_E, + matrix_E_reordered.device_ref()); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), + matrix_E.host_ref(), problem_size.m(), + problem_size.k()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm(problem_size, ElementC(alpha), + matrix_A_uncompressed.host_view(), matrix_B.host_view(), + ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "E:\n" << matrix_E.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..5caaf38ace92758bbc86970d8d4ff339d87348ab --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..4e617d6327594570b1a88a5b28f2ec4d0467b534 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h @@ -0,0 +1,387 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, CacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, + typename MmaCore::MmaPolicy, Stages>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributeMaxDynamicSharedMemorySize error: " + << cudaGetErrorString(result); + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributePreferredSharedMemoryCarveout error: " + << cudaGetErrorString(result); + } + + test::gemm::threadblock::kernel_multistage_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_multistage_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb62f9a39fe4472f77446efc591267001758c58 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -0,0 +1,353 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_, + /// Number of stages + int Stages = 2> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + static const int kStages = Stages; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define MmaPipeline Single Stage + using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define MmaPipeline Two Stages + using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. Two stages) + using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + bool sufficient() { + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result) << " on device " << GetCudaDevice(); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..36e55b2542b2258542336a052cdd14bf4b85f78d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h @@ -0,0 +1,370 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e5fdc07769726353b33c1a5da65dedfadb4ce1e7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma_planar_complex( + cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::Element *ptr_A, + int64_t imaginary_stride_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::Element *ptr_B, + int64_t imaginary_stride_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { + + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA iterator_A_real(params_A, ptr_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B_real(params_B, ptr_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum.real); + + iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename Mma_> +struct TestbedPlanarComplex { + + using Mma = Mma_; + using ThreadblockShape = typename Mma::Shape; + using IteratorA = typename Mma::IteratorA; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using IteratorB = typename Mma::IteratorB; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Mma::ElementC; + using ElementAccumulator = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + using ThreadMapA = typename Mma::IteratorA::ThreadMap; + using ThreadMapB = typename Mma::IteratorB::ThreadMap; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = Mma::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + Mma::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + Mma::kCacheOpB; + + // + // Data members + // + + cutlass::HostTensorPlanarComplex matrix_A; + cutlass::HostTensorPlanarComplex matrix_B; + cutlass::HostTensorPlanarComplex matrix_C_computed; + cutlass::HostTensorPlanarComplex matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedPlanarComplex(int m, int n, int k) + : problem_size(m, n, k) { + + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + + for (int i = 0; i < matrix_A.capacity() * 2; ++i) { + matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); + } + /* + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity() * 2); + */ + } else if (init_A == cutlass::Distribution::Identity) { + //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + + + } else if (init_B == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity() * 2); + + for (int i = 0; i < matrix_B.capacity() * 2; ++i) { + matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); + } + + + } else if (init_B == cutlass::Distribution::Identity) { + + //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + + } else { + return false; + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma_planar_complex<<>>( + problem_size, + params_A, + matrix_A.device_data(), + matrix_A.imaginary_stride(), + params_B, + matrix_B.device_data(), + matrix_B.imaginary_stride(), + matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), + matrix_C_computed.imaginary_stride() + ); + + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + cutlass::complex(ElementAccumulator(1)), + matrix_A.host_ref(), + Mma::kTransformA, + matrix_B.host_ref(), + Mma::kTransformB, + cutlass::complex(ElementAccumulator(0)), + matrix_C_reference.host_ref(), + matrix_C_reference.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), + matrix_C_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..921d1abdc40c2040104815cfffb8b2ea32384136 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h @@ -0,0 +1,1543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/platform/platform.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + + ++iter_A; + ++iter_B; + + mma(accum, frag_A, frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The inner product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), + tensor_A.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), + tensor_B.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] + << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] + << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel_transform( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + using TransformedFragmentA = typename Mma::TransformedFragmentA; + using TransformedFragmentB = typename Mma::TransformedFragmentB; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA loaded_frag_A; + FragmentB loaded_frag_B; + TransformedFragmentA transformed_frag_A; + TransformedFragmentB transformed_frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(loaded_frag_A); + iter_B.load(loaded_frag_B); + + ++iter_A; + ++iter_B; + + mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, + loaded_frag_B); + + mma(accum, transformed_frag_A, transformed_frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct TransformTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<>>( + tensor_D_computed.device_data(), tensor_A.device_data(), + tensor_B.device_data(), tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TransformedTestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformedTestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void sparse_kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + typename Mma::ElementE const *input_E, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer + smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / + Mma::kSparse / Mma::kElementsPerElementE> + smem_buffer_E; + + __syncthreads(); + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + + typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_E.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_E, i) = + cutlass::ReferenceFactory::type>::get(input_E, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + using FragmentE = typename Mma::FragmentE; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed( + {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); + typename Mma::LayoutB layout_B = + Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + typename Mma::LayoutE layout_E = + Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, + Mma::Shape::kK / Mma::kSparse / + Mma::kElementsPerElementE / Mma::kInterleaved}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + FragmentE frag_E; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + iter_E.load(frag_E); + + ++iter_A; + ++iter_B; + ++iter_E; + + mma(accum, frag_A, frag_B, accum, frag_E); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct SparseTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + static int const Sparse = Mma::kSparse; + static int const MetaSizeInBits = Mma::kMetaSizeInBits; + static int const MaxID2 = Mma::kMaxID2; + static int const Interleaved = Mma::kInterleaved; + + using ElementE = typename Mma::ElementE; + + static int const ElementsPerElementE = Mma::kElementsPerElementE; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = + cutlass::layout::ColumnMajorInterleaved; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, + ThreadblockShape::kK / Sparse)); + tensor_A_uncompressed.reset( + cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_E.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + tensor_E_reordered.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta( + tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_E_reordered.sync_device(); + + // launch kernel + sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_E_reordered.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), Shape::kM, Shape::kK); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..3311e915db892466a9a4c52c82d100c2e1319966 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h @@ -0,0 +1,43 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace nvrtc { + +extern char const *kCutlassHeaders[]; +extern char const *kCutlassHeaderNames[]; +extern size_t const kCutlassHeaderCount; +} // namespace nvrtc +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55df44379c847034ed38cfab23477331ee4a537c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace nvrtc { +namespace thread { + +template< + typename ElementA, typename ElementB, typename ElementC, + typename TileShape, typename ClusterShape, + bool kTransA, bool kTransB, + int RANK_M, int RANK_N, int RANK_K, int RANK_L +> +struct ContractionKernel { + +using ElementScalar = float; +using ElementAccum = float; +using EpilogueThread = cutlass::epilogue::thread::LinearCombination; + +static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; +static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + +/// Kernel config +typedef int64_t stride_type; +typedef int32_t extent_type; + +static constexpr const stride_type* stride_null = nullptr; +static constexpr const extent_type* extent_null = nullptr; + +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} + +using StrideA = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideB = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideC = decltype(cute::make_stride( + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using ProblemShape = decltype(cute::make_shape( + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0))); + +using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 16 / sizeof(ElementA), + ElementB, StrideB, 16 / sizeof(ElementB), + ElementAccum, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized +>::CollectiveOp; + +using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; +using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; +using Kernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveOp, + CollectiveEpilogue>; + +}; + +} // namespace nvrtc +} // namespace thread diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..576f55cd868cd64c8c09c055d8b9a956e40c87ae --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/array.h" + +namespace test { +namespace nvrtc { +namespace kernel { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void testbed_kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +} +} +} +} + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h new file mode 100644 index 0000000000000000000000000000000000000000..c7e6e94691c82b2f343959421c884c8b0b06f9b4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h @@ -0,0 +1,30 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h new file mode 100644 index 0000000000000000000000000000000000000000..5ba5432fd568af71e15b20b8cdab1571f303bcdf --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +typedef char int8_t; +typedef unsigned char uint8_t; +typedef short int16_t; +typedef unsigned short uint16_t; +typedef int int32_t; +typedef unsigned int uint32_t; +typedef long long int int64_t; +typedef unsigned long long int uint64_t; + +#if defined __x86_64__ && !defined __ILP32__ +# define __WORDSIZE 64 +#else +# define __WORDSIZE 32 +#endif + + +/* Small types. */ + +/* Signed. */ +typedef signed char int_least8_t; +typedef short int int_least16_t; +typedef int int_least32_t; +#if __WORDSIZE == 64 +typedef long int int_least64_t; +#else +__extension__ +typedef long long int int_least64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_least8_t; +typedef unsigned short int uint_least16_t; +typedef unsigned int uint_least32_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_least64_t; +#else +__extension__ +typedef unsigned long long int uint_least64_t; +#endif + + +/* Fast types. */ + +/* Signed. */ +typedef signed char int_fast8_t; +#if __WORDSIZE == 64 +typedef long int int_fast16_t; +typedef long int int_fast32_t; +typedef long int int_fast64_t; +#else +typedef int int_fast16_t; +typedef int int_fast32_t; +__extension__ +typedef long long int int_fast64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_fast8_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_fast16_t; +typedef unsigned long int uint_fast32_t; +typedef unsigned long int uint_fast64_t; +#else +typedef unsigned int uint_fast16_t; +typedef unsigned int uint_fast32_t; +__extension__ +typedef unsigned long long int uint_fast64_t; +#endif + +/* Types for `void *' pointers. */ +#if __WORDSIZE == 64 +# ifndef __intptr_t_defined +typedef long int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned long int uintptr_t; +#else +# ifndef __intptr_t_defined +typedef int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned int uintptr_t; +#endif + + +/* Largest integral types. */ +#if __WORDSIZE == 64 +typedef long int intmax_t; +typedef unsigned long int uintmax_t; +#else +__extension__ +typedef long long int intmax_t; +__extension__ +typedef unsigned long long int uintmax_t; +#endif + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8fd6863e8fa003d3fbc4e0b498e3b9b454ade190 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h @@ -0,0 +1,398 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/gemm/thread/mma.h" +#include "../kernel/thread/testbed_kernel.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/trace.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include +#include +#include "../cutlass/nvrtc/environment.h" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace nvrtc { +namespace thread { + +#define NVRTC_RETURN_IF_ERROR(api) \ + do { \ + nvrtcResult _result = api; \ + if (_result != NVRTC_SUCCESS) { \ + CUTLASS_TRACE_HOST("Nvrtc error: " << _result); \ + return false; \ + } \ + } while(0) + +inline const char * cuda_source_fmt = R"""( + +#include "kernel/thread/contraction.hpp" + +using Operator = %s; + +extern "C" __global__ void global_entry(__grid_constant__ Operator::Params const params) { + extern __shared__ char smem[]; + + Operator op; + op(params, smem); +} + +)"""; + +struct TestbedKernel { + static bool compile(std::string const &kernel, std::vector const &opts) { + int sz = std::snprintf(nullptr, 0, cuda_source_fmt, kernel.c_str()); + std::vector cuda_source(sz + 1); + std::snprintf(&cuda_source[0], cuda_source.size(), cuda_source_fmt, kernel.c_str()); + + nvrtcProgram program; + NVRTC_RETURN_IF_ERROR( + nvrtcCreateProgram( + &program, + cuda_source.data(), + nullptr, + static_cast(cutlass::nvrtc::kCutlassHeaderCount), + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames) + ); + + nvrtcResult compile_result = + nvrtcCompileProgram( + program, + static_cast(opts.size()), + opts.data()); + + size_t log_size; + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLogSize(program, &log_size) + ); + + if (log_size > 1) { + auto log = std::make_unique(log_size); + + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLog(program, log.get()) + ); + + std::cout << log.get() << std::endl; + } + + NVRTC_RETURN_IF_ERROR(compile_result); + + NVRTC_RETURN_IF_ERROR( + nvrtcDestroyProgram(&program) + ); + + return true; + } +}; + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + static inline bool check_nvrtc_error(nvrtcResult error) { + if (error != NVRTC_SUCCESS) { + std::cerr << "failed to compile "; + return false; + } + return true; + } + + /// Runs the test + bool run(std::string const &gemm_traits) { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + +#if 0 + // launch kernel + cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + +#else + // Instantiate gemm_kernel + nvrtcResult result_nvrtc; + nvrtcProgram program; + static char const *src = + "#include \"cutlass/gemm/thread/mma.h\"\n" + "#include \"cutlass/gemm/gemm.h\"\n" + "#include \"cutlass/layout/matrix.h\"\n" + "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" + ; + + std::string type_name; +#if 0 + // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names + // As altername solution we might want to implement to_string() to get the traits string. + nvrtcGetTypeName(&type_name); +#else + type_name = gemm_traits; +#endif + + result_nvrtc = nvrtcCreateProgram(&program, + src, + NULL, + (int)cutlass::nvrtc::kCutlassHeaderCount, + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames); + check_nvrtc_error(result_nvrtc); + + std::string gemm_kernel_instantiation = + "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; + nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); + + const char *opts[] = {"--gpu-architecture=compute_75", + "--std=c++17", + "--include-path=/usr/local/cuda-10.1/include"}; + + result_nvrtc = nvrtcCompileProgram(program, 3, opts); + if (result_nvrtc != NVRTC_SUCCESS) { + size_t logSize; + nvrtcGetProgramLogSize(program, &logSize); + std::vector log(logSize); + nvrtcGetProgramLog(program, log.data()); + std::cout << "Compile log:" << std::endl << log.data() << std::endl; + } + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // The lowered name is the name of the template instantiation in the generated PTX code. + char const *gemm_kernel_lowered_name; + nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards + size_t ptx_size; + result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + std::vector ptx(ptx_size); + result_nvrtc = nvrtcGetPTX(program, ptx.data()); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // we do not need the nvrtc program anymore + //nvrtcDestroyProgram(&program); + + CUmodule module; + CUresult result_cuda; + result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + CUfunction kernel; + result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + void* d_a = (void*)tensor_A.device_data(); + void* d_b = (void*)tensor_B.device_data(); + void* d_c = (void*)tensor_C.device_data(); + void* d_d = (void*)tensor_D_computed.device_data(); + void* args[] = { &d_d, &d_a, &d_b, &d_c }; + + // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra + result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } else { +} +#endif + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cout << "CUDA ERROR: " << cudaGetErrorString(result); + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + if(!passed) std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + std::cout << "passed " << passed << std::endl; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace nvrtc +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6cc2946a2c51cfb8c1971345c81c1910bd667208 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Common Testbed file shared by Pipeline unit tests +*/ + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "../common/cutlass_unit_test.h" + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +// Command line test options +struct Options { + // + // Data Members + // + bool help; + bool verification_enabled; + int SM_count; + int clock_MHz; + + // + // Methods + // + Options(): + help(false), + verification_enabled(true), + SM_count(116), + clock_MHz(1477) + { } + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); + cmd.get_cmd_line_argument("sm-count", SM_count, 116); + cmd.get_cmd_line_argument("clock", clock_MHz, 1477); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +struct Testbed { +private: + // Commandline options + Options options; + + void run_test(uint32_t const kNumIters) { + + // Run CuTe Gemm + Pipeline pipeline; + + cudaError_t result = pipeline.run(kNumIters); + + CUTE_CHECK_LAST(); + } + + +public: + Testbed(Options const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + + std::array kNumIters; + + for (size_t i = 0; i < kNumIters.size(); ++i) { + kNumIters[i] = static_cast( (rand() % 1000) + 1 ); + } + + for (int n : kNumIters) { + std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; + run_test(n); + } + + return true; + } +}; diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h new file mode 100644 index 0000000000000000000000000000000000000000..50a68a1437956c95aa4e7912e93adc8b1481c9cc --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed file used by cluster launch control pipeline unit test +*/ + +// + +// + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +// Command line test options +struct OptionsClusterLaunch { + // + // Data Members + // + bool help = false; + bool verification_enabled = true; + int SM_count = 116; + int clock_MHz = 1477; + dim3 grid_dim = {0,0,0}; + + // + // Methods + // + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, verification_enabled); + cmd.get_cmd_line_argument("sm-count", SM_count, SM_count); + cmd.get_cmd_line_argument("clock", clock_MHz, clock_MHz); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +class TestbedClusterLaunch { +private: + // Commandline options + OptionsClusterLaunch options; + + bool run_test() { + + // Run CuTe Gemm + Pipeline pipeline; + + bool success = false; + cudaError_t result = pipeline.run(success, this->options.grid_dim); + + CUTE_CHECK_LAST(); + return success; + } + + +public: + TestbedClusterLaunch(OptionsClusterLaunch const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + +#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + printf( + "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be set, but it is not. \n" + "This test is waived.\n" + ); + return true; +#endif + +#if 0 + bool is_success = false; + for (int i = 0; i< 10; i++){ + printf("iteration = %d\n", i); + is_success = run_test(); + if ( not is_success ) + return is_success; + } + return is_success; +#else + // Run the test with single launch + return run_test(); +#endif + } +}; diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e44a42463ae95e4f76388d791c661de875092c93 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..239f228831a25527106af1659383112535943df1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +namespace test { +namespace reduction { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_host { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + // + // Data members + // + + cutlass::Array tensor_in; + cutlass::Array reduced_tensor_computed; + cutlass::Array reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_host() { + tensor_in.clear(); + reduced_tensor_computed.clear(); + reduced_tensor_reference.clear(); + } + + /// Runs the test + bool run() { + + // + // initialize memory + // + + for(int i = 0; i < N; i++) + tensor_in.at(i) = Element(i); + + + Reduce reduce; + + cutlass::Array *out_ptr = &reduced_tensor_computed; + out_ptr[0] = reduce(tensor_in); + + // + // Reference implementation + // + Element e(0); + for (int i = 0; i < N; i++) + e = e + Element(i); + + reduced_tensor_reference.at(0) = e; + + // + // Verify equivalence + // + + // compare + bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; + + EXPECT_TRUE(passed) + << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" + << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" + << std::endl; + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level reduction kernel +template +__global__ void kernel_reduce(Element const *array_in, Element *result) { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + Reduce reduce; + + auto ptr_in = reinterpret_cast const *>(array_in); + auto result_ptr = reinterpret_cast *>(result); + auto in = *ptr_in; + result_ptr[0] = reduce(in); +} + + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_device { + + using Layout = cutlass::layout::PackedVectorLayout; + + // + // Data members + // + + cutlass::HostTensor tensor_in; + cutlass::HostTensor reduced_tensor_computed; + cutlass::HostTensor reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_device() { + + tensor_in.reset(cutlass::make_Coord(N), true); + reduced_tensor_computed.reset(cutlass::make_Coord(1), true); + reduced_tensor_reference.reset(cutlass::make_Coord(1), true); + } + + + /// Runs the test + bool run() { + + // + // initialize memory + // + + cutlass::reference::host::TensorFill( + tensor_in.host_view(), + Element(1) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_computed.host_view(), + Element(0) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_reference.host_view(), + Element(N) + ); + + tensor_in.sync_device(); + reduced_tensor_computed.sync_device(); + reduced_tensor_reference.sync_device(); + + /// call the kernel + kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( + tensor_in.device_data(), + reduced_tensor_computed.device_data() + ); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + // Copy back results + reduced_tensor_computed.sync_host(); + + // Verify equivalence + bool passed = cutlass::reference::host::TensorEquals( + reduced_tensor_computed.host_view(), + reduced_tensor_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" + << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" + << std::endl; + + return passed; + } +}; + +} // namespace thread +} // namespace reduction +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4e7de4351076dba3a699b4cb1c8a6e01485bc20 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Compress utils specific for SM90 structure sparse kernels +*/ + +#pragma once + +#include // std::fill +#include // std::array +#include +#include // std::mt19937 + +#include "cute/container/bit_field.hpp" // cute::bit_field +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor +#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter + +namespace cutlass +{ +namespace transform +{ +namespace kernel +{ + +using namespace cute; + +namespace detail { + + template + CUTLASS_HOST_DEVICE + static uint8_t + encode_in_chunk_idx_legacy(int in_chunk_idx){ + if (sizeof(T) == 4) { + return in_chunk_idx == 0 ? 0b0100 : 0b1110; + } + else { + uint8_t res = 0; + if (in_chunk_idx == 0) { + res = 0b00; + } + else if (in_chunk_idx == 1) { + res = 0b01; + } + else if (in_chunk_idx == 2) { + res = 0b10; + } + else { + res = 0b11; + } + return res; + } + } + + template < + class SparseConfig, + class EngineA, + class LayoutA, + class EngineAc, + class LayoutAc + > + CUTLASS_HOST_DEVICE + static void + compress_two_chunks_legacy( + Tensor tensorA, + Tensor tensorAc, + uint8_t& meta_two_chunk, + int effective_elems) { + + using ElementA = typename EngineAc::value_type; + + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{}; + static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + /* + Legal metadata chunk in SM90 + Index Bin HEX + 0, 1 0b0100 4 + 1, 2 0b1001 9 + 2, 3 0b1110 E + 0, 2 0b1000 8 + 1, 3 0b1101 D + 0, 3 0b1100 C + 2, 1 0b0110 6 (Not used) + ----------------------------------- + TF32 + 0 0b0100 4 + 1 0b1110 E + */ + + if (effective_elems <= 0) { + return; + } + + // initialize + // 0 is the initial value for this function while 0x44 is the initial value for hardware. + meta_two_chunk = 0; + + for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) { + // If Only One Chunk within this Two Chunk + if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) { + break; + } + /// init result; + int non_zero_cnt = 0; + int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 }; + ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + + for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) { + bool is_nz = true; + ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + /// Check if subchunk is non-zero + for(int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) { + int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx; + subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0); + + ElementA zero = static_cast(0); + ElementA minus_zero = static_cast(ElementA(1) << cutlass::sizeof_bits_v - 1); + if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) { + if (non_zero_cnt >= PhysicalSubChunk) { + #ifdef __CUDA_ARCH__ + asm volatile ("brkpt;\n" ::); + #else + throw std::runtime_error("Found extra non-zero elements in a chunk!\n"); + #endif + } + is_nz = false; + } + } + + /// There is non-zero element in the subchunk + if(!is_nz) { + nnz_chunk_idx[non_zero_cnt] = subchunk_idx; + memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + non_zero_cnt++; + } + } + + /* + Special cases + nnz == 1 and non-tf32 and nnz_idx = 3 + */ + ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + if constexpr (sizeof_bits_v < 32) { + if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) { + memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + nnz_chunk_idx[0] = 0; + } + else if (non_zero_cnt == 1) { + memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + } + } + + /// Setup metadata + uint8_t meta_chunk = 0; + for (int i = 0; i < PhysicalSubChunk; i++) { + meta_chunk = static_cast(meta_chunk | (encode_in_chunk_idx_legacy(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma))); + for(int j = 0; j < ElemsARawPerElementAMmaRaw; j++) { + tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j]; + } + } + meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{}))); + } + } +} + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class SM90StructuredSparseCompressorLegacy { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + // * EltA + using ElementA = ElementA_; + using ElementAUint = cute::uint_bit_t>; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + using ArrayElementA = cute::conditional_t>, + ElementA>; + using ElementAMma = typename SparseConfig::ElementAMma; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + using LayoutATag = LayoutATag_; + using LayoutA = LayoutATag; + using StrideA = cutlass::gemm::TagToStrideA_t; + + // * EltE + using ElementEMma = typename SparseConfig::ElementEMma; + using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + + // * AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + // * Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + // Required by `device_kernel` + static constexpr int MaxThreadsPerBlock = 1; + static constexpr int MinBlocksPerMultiprocessor = 1; + using ArchTag = arch::Sm90; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct TransformArguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementA* ptr_ACompress{nullptr}; + ElementEMmaRaw* ptr_E{nullptr}; + }; + + using TransformParams = TransformArguments; + + struct Arguments { + ProblemShape problem_shape{}; + TransformArguments transform{}; + KernelHardwareInfo hw_info{}; + }; + + struct Params { + ProblemShape problem_shape{}; + TransformParams transform{}; + KernelHardwareInfo hw_info{}; + void* workspace = nullptr; + }; + + static Params + to_underlying_arguments(Arguments & args, void* workspace) { + return Params{{args.problem_shape}, + {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, + {args.hw_info}, + workspace}; + } + + static Status + can_implement(Arguments const& args) { + auto [M, N, K, L] = args.problem_shape; + if (K % LogicalElemsAPerChunk != 0) { + CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size\n"); + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + auto problem = args.problem_shape; + const int m = cute::size<0>(problem); + const int k = cute::size<2>(problem); + const int l = cute::size<3>(problem); + const int metadata_k = round_up(k, TensorEAlignmentK); + const int metadata_m = round_up(m, TensorEAlignmentM); + const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l; + return metadata_bytes; + } + + static Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + cudaError_t cuda_error; + + auto workspace_size = get_workspace_size(args); + if (workspace_size == 0) { + return Status::kSuccess; + } else if (workspace == nullptr) { + return Status::kErrorInternal; + } + + cudaPointerAttributes attri; + cuda_error = cudaPointerGetAttributes(&attri, workspace); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } + + if ( attri.type == cudaMemoryTypeDevice ) { +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + CUTLASS_ASSERT(cuda_adapter); + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } +#else + cudaMemsetAsync(workspace, 0, workspace_size, stream); + cuda_error = cudaGetLastError(); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } +#endif + } else { + memset(workspace, 0, workspace_size); + } + + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + return dim3(1, 1, 1); + } + + static dim3 + get_block_shape() { + return dim3(1, 1, 1); + } + + CUTE_HOST_DEVICE + void + operator()(Params params, char* smem_buf = nullptr) { + run(params, smem_buf); + } + + CUTE_HOST_DEVICE + static void + run(Params params, char* smem_buf = nullptr) { + do_compress_device_host(params); + } + +private: + + CUTE_HOST_DEVICE + static void + do_compress_device_host(Params params) { + auto [m, n, k, l] = params.problem_shape; + auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; + auto workspace = params.workspace; + + const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK; + const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM; + const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK; + const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM; + const int k_compressed = aligned_k / ElementASparsity{}; + + // Convert to CuTe tensors. But don't want to use sparse_ptr, which is making everything complicated here. + cute::Tensor tensorA = make_tensor(recast_ptr(ptr_A), make_layout(make_shape(m, k, l), dA)); + + cute::Tensor tensorAc = make_tensor(recast_ptr(ptr_ACompress), + make_shape(aligned_m, k_compressed, l), + make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l))); + + cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr>(workspace), + make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l), + make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k)); + + cute::Tensor tensorE_raw_compress = recast(tensorE_raw_compress_logical); + + // The following vars are all logical. + int atom_m = size<0>(TensorEAtom{}); + int atom_k = size<1>(TensorEAtom{}); + int tiled_m = metadata_m / atom_m; + int tiled_ke = metadata_k / atom_k; + // Col major when viewing atoms + int stride_tile_m = cosize(TensorEAtom{}); + int stride_tile_ke = atom_k * metadata_m; + + // Logical metadata tensor + cute::Tensor tensorE_logical = make_tensor(recast_ptr>(ptr_E), + make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m), + append(shape<1>(TensorEAtom{}), tiled_ke), + shape<2>(tensorE_raw_compress_logical)), + make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m), + append(stride<1>(TensorEAtom{}), stride_tile_ke), + stride<2>(tensorE_raw_compress_logical)))); + // Physical metadata tensor + cute::Tensor tensorE = recast(tensorE_logical); + + // void do_init() + cute::clear(tensorAc); + cute::clear(tensorE_raw_compress); + + // void do_raw_compress() + using TileStepA = Int; + using TileStepAc = Int; + + cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _)); + cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _)); + + for (int batch_idx = 0; batch_idx < l; batch_idx++) { + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) { + int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{})); + detail::compress_two_chunks_legacy(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx), + effective_elems); + } + } + } + + // void do_reorder() + // Fast path when we don't permute. + if constexpr (sizeof_bits_v <= 8) { + memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size()); + } + else { + cute::copy(tensorE_raw_compress, tensorE); + } + + #if 0 + print("--> TensorA\n"); + auto tensorA_eltA = cute::recast(tensorA); + cute::print_tensor(tensorA_eltA); printf("\n\n"); + + print("--> REF TensorAC\n"); + auto tensorAc_eltA = cute::recast(tensorAc); + cute::print_tensor(tensorAc_eltA); printf("\n\n"); + + print("--> REF TensorE\n"); + cute::print_tensor(tensorE); printf("\n\n"); + #endif + + } +}; + +} // namespace kernel +} // namespace transform +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f44458244e0d3c4c80ecc29a0115cd6906211559 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp @@ -0,0 +1,877 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * @brief Test for structured sparse gemm compressor device kernel + */ + +#pragma once + +#include // cudaGetLastError + +#include // uint64_t +#include // printf +#include // malloc +#include // std::cout +#include +#include + +#include "cute/layout.hpp" // cute::make_shape +#include "cute/util/type_traits.hpp" // cute::is_same_v +#include "cutlass/coord.h" // cutlass::make_Coord +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/kernel_hardware_info.hpp" // cutlass::KernelHardwareInfo +#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory +#include "cutlass/numeric_types.h" // cutlass::sizeof_bits, cutlass::float_ +#include "cutlass/tensor_view.h" // cutlass::TensorView +#include "cutlass/transform/device/transform_universal_adapter.hpp" // cutlass::transform::device::TransformUniversalAdapter +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // cutlass::transform::kernel::StructuredSparseCompressorUtility +#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation +#include "cutlass/util/distribution.h" // cutlass::Distribution +#include "cutlass/util/host_tensor.h" // cutlass::HostTensor +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals +#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill +#include "cutlass/detail/collective.hpp" + +#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor +#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE + + +#define CUDA_CHECK_FALSE(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return false; \ + } \ + } + +#define CUDA_CHECK(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return; \ + } \ + } + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// * Test Bed +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test +{ +namespace transform +{ +namespace device +{ + +// Helper Functions +template +bool +initialize_tensor(cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) +{ + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else { + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Testbed +template +struct TestbedSparseGemmCompressor { +public: + using Compressor = Compressor_; + using CompressorKernel = typename Compressor::TransformKernel; + + using ElementA = typename CompressorKernel::ElementA; + using LayoutATag = typename CompressorKernel::LayoutATag; + using StrideA = typename CompressorKernel::StrideA; + using ArrayElementA = + ElementA + ; + + using ElementE = typename CompressorKernel::ElementEMmaRaw; + using LayoutETag = cutlass::layout::RowMajor; // We don't care about the major here, just to allocate tensor + + using SparseConfig = typename CompressorKernel::SparseConfig; + using ProblemShapeType = typename CompressorKernel::ProblemShape; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorHost = cutlass::transform::device::TransformUniversalAdapter; + + static constexpr auto LogicalElemsAPerChunk = CompressorKernel::LogicalElemsAPerChunk; + static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk; + + struct Data { + // Data Storage + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_A_Comp_ref; + cutlass::HostTensor tensor_E_ref; + }; + + struct CudaRAII { + cudaStream_t stream; + cudaEvent_t start; + cudaEvent_t stop; + + CudaRAII(){ + CUDA_CHECK(cudaStreamCreate( &stream )); + CUDA_CHECK(cudaEventCreate( &start )); + CUDA_CHECK(cudaEventCreate( &stop )); + }; + + CudaRAII(const CudaRAII&) = delete; + CudaRAII& operator=(const CudaRAII&) = delete; + CudaRAII(CudaRAII&&) = delete; + CudaRAII& operator=(CudaRAII&&) = delete; + + ~CudaRAII(){ + CUDA_CHECK(cudaStreamDestroy( stream )); + CUDA_CHECK(cudaEventDestroy( start )); + CUDA_CHECK(cudaEventDestroy( stop )); + } + }; + +public: + TestbedSparseGemmCompressor( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 7) + : init_A(init_A_) + , init_E(init_E_) + , init_A_Comp(init_A_Comp_) + , seed(seed_) + { + } + + bool valid_test(ProblemShapeType problem_shape_MNKL) + { + const int GemmK = cute::size<2>(problem_shape_MNKL); + + if ( GemmK % LogicalElemsAPerChunk != 0 ) { + printf("GemmK needs to be multiplier of LogicalElemsAPerChunk\n"); + return false; + } + + return true; + } + + bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas) + { + CUDA_CHECK_FALSE(cudaGetLastError()); + + // In unit of ElementARaw + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + // Compressor utility to get allocated data size + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorA + // In unit of ElementARaw, after alignment requirement + // M-dim: no alignment requirement + // K-dim: multiplier of chunk size + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical(); + const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical(); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + const int GemmMAlignedE = compressor_utility.get_metadata_m_physical(); + const int GemmKAlignedE = compressor_utility.get_metadata_k_physical(); + + auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK); + auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE); + auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC); + + typename LayoutATag::Stride stride_factor_A; + typename LayoutETag::Stride stride_factor_E; + + datas.tensor_A.resize(a_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + datas.tensor_A_Comp.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + datas.tensor_A_Comp_ref.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A), + false); + datas.tensor_E.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + datas.tensor_E_ref.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E), + false); + + EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, seed + 2)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5)); + + compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + datas.tensor_A.sync_device(); + datas.tensor_A_Comp.sync_device(); + datas.tensor_E.sync_device(); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + return true; + } + + bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr) + { + CudaRAII cuda_raii; + + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.device_data(), + stride_a, + datas.tensor_A_Comp.device_data(), + datas.tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream)); + + status = compressor_op.run(cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop)); + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + if ( time != nullptr ){ + CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop)); + } + + datas.tensor_A_Comp.sync_host(); + datas.tensor_E.sync_host(); + + #if 0 + { + printf("\n--> DEVICE OUTPUT\n"); + printf("datas.tensor_A\n"); + std::cout << datas.tensor_A.host_view() << std::endl << std::endl; + printf("datas.tensor_A_Comp\n"); + std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl; + printf("datas.tensor_E\n"); + std::cout << datas.tensor_E.host_view() << std::endl << std::endl; + } + #endif + + return true; + } + + bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas) + { + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + typename CompressorKernelHost::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.host_data(), + stride_a, + datas.tensor_A_Comp_ref.host_data(), + datas.tensor_E_ref.host_data()}, + {}}; + + const auto can_imp = CompressorKernelHost::can_implement(arguments); + if (can_imp != cutlass::Status::kSuccess) { + printf("can_implement() check failed\n"); + return false; + } + + // Relies on std::vector for RAII + auto workspace_size = + static_cast::size_type>(CompressorKernelHost::get_workspace_size(arguments)); + std::vector workspace_vector(workspace_size); + auto workspace = static_cast(workspace_vector.data()); + + cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace); + if (status != cutlass::Status::kSuccess) { + printf("initialize_workspace() failed\n"); + return false; + } + + auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace); + CompressorKernelHost::run(params); + + return true; + } + + bool compare_reference(Data& datas) + { + bool check_tensor_a_compressed = + cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view()); + if (!check_tensor_a_compressed) { + printf("A-Compressed Mismatch\n"); + } + + bool check_tensor_e = cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view()); + if (!check_tensor_e) { + printf("E Mismatch\n"); + } + + return check_tensor_a_compressed && check_tensor_e; + } + + bool run_auto_small() + { + return run_auto(true); + } + + bool run_auto(bool run_small = false) + { + constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + + constexpr int GemmN = 1; + + using ProblemType = typename std::array; + + std::vector problems; + + const std::vector problems_multiplier_of_tensor_e_atom = { + // * Regular Cases (multiplier of TensorEAlignment) + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3}, + }; + + const std::vector problems_multiplier_of_tensor_e_atom_large = { + // * Large Case (multiplier of TensorEAlignment) + {TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1}, + // {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2}, + // {TensorEAlignmentM * 12, GemmN, TensorEAlignmentK * 15, 3}, + }; + + const std::vector problems_multiplier_of_twochunk { + // * Corner Cases + {4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + }; + + const std::vector problems_multiplier_of_onechunk { + {4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + }; + + // Run small only run multiplier of chunk size cases + if (run_small) { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + } + // Run full run all corner cases + else { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end()); + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end()); + problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end()); + } + + for (const auto& problem_shape_MNKL : problems) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + bool passed = run({GemmM, GemmN, GemmK, GemmL}); + printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, GemmK, GemmL, passed ? "PASS" : "FAIL"); + CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << passed ? " PASS" : " FAIL"); + if (not passed) { + return false; + } + } + + return true; + } + + bool run(ProblemShapeType problem_shape_MNKL) + { + // Check if valid test + if (not valid_test(problem_shape_MNKL)) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // Data Storage + Data datas; + + // Initialize Data + if (not initialize(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Host Ref) + if (not run_host_ref(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_host() fail\n"); + return false; + } + + // Run Compressor (Device) + if (not run_device(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + // Verify + if (not compare_reference(datas)) { + CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + printf("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + return false; + } + // else { + // printf("DEVICE <-> HOST PASS\n"); + // } + + return true; + } + + bool benchmark(ProblemShapeType problem_shape_MNKL) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL); + + // Check if valid test + if (valid_test(problem_shape_MNKL) == false) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // 2 warm-up iterations and 10 timing iterations + constexpr int num_warmup = 5; + constexpr int num_iter = 10; + + // Duplicate data to mimic cold cache + Data data[num_warmup + num_iter]; + double total_time_milliseconds{0.0}; + + for (int i = 0; i < num_warmup + num_iter; ++i ) { + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i ); + + auto& datum_i = data[i]; + + // Initialize Data + if (initialize(problem_shape_MNKL, datum_i) == false) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Device) + double time_i_milliseconds{0.0f}; + if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + if ( i >= num_warmup ) { + total_time_milliseconds += time_i_milliseconds; + } + } + + const double mean_time_milliseconds = total_time_milliseconds / num_iter; + printf("Mean time (ms): %.5f\n", mean_time_milliseconds); + + return true; + } + +public: + // Data Init Setting + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_A_Comp; + cutlass::Distribution::Kind init_E; + uint64_t seed; +}; + +} // namespace device +} // namespace transform +} // namespace test diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h new file mode 100644 index 0000000000000000000000000000000000000000..df241e3ca6e6e584af7351402d990a8028e2abed --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/arch/arch.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ArchMap; + +template <> struct ArchMap { + static int const kMin = 50; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 60; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 61; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 75; +}; + +template struct ArchMap { + static int const kMin = 75; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 80; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 86; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 89; + static int const kMax = 100; +}; + +template struct ArchMap { + static int const kMin = 90; + static int const kMax = 1024; +}; + +// Arch conditional WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + +// Arch conditional sparse WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + + +template struct ArchMap { + static int const kMin = 100; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 100; + #if (__CUDACC_VER_MAJOR__ >= 13) + static int const kMax = 110; + #else + static int const kMax = 103; + #endif // __CUDACC_VER_MAJOR__ >= 13 +}; + +template struct ArchMap { + static int const kMin = 103; + static int const kMax = 1024; +}; +template <> struct ArchMap { + static int const kMin = 103; + static int const kMax = 103; +}; + +template struct ArchMap { + static int const kMin = 120; + static int const kMax = 121; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h new file mode 100644 index 0000000000000000000000000000000000000000..5e80c124e59d24cd90c7c1b0c06bcc3bedfee62f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h @@ -0,0 +1,815 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct MathInstructionDescription { + + /// Shape of the target math instruction + cutlass::gemm::GemmCoord instruction_shape; + + /// Describes the data type of the internal accumulator + NumericTypeID element_accumulator; + + /// Classification of math instruction + OpcodeClassID opcode_class; + + /// Type of math operation performed + MathOperationID math_operation; + + // + // Methods + // + + MathInstructionDescription( + cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), + NumericTypeID element_accumulator = NumericTypeID::kInvalid, + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd + ): + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} + + // Equality operator + inline + bool operator==(MathInstructionDescription const& rhs) const{ + return ( + (instruction_shape == rhs.instruction_shape) && + (element_accumulator == rhs.element_accumulator) && + (opcode_class == rhs.opcode_class) && + (math_operation == rhs.math_operation)); + } + + // Inequality operator + inline + bool operator!=(MathInstructionDescription const& rhs) const { + return !(*this == rhs); + } + +}; + +/// Structure describing the tiled structure of a GEMM-like computation +struct TileDescription { + + /// Describes the shape of a threadblock (in elements) + cutlass::gemm::GemmCoord threadblock_shape; + + /// Describes the number of pipeline stages in the threadblock-scoped mainloop + int threadblock_stages; + + /// Number of warps in each logical dimension + cutlass::gemm::GemmCoord warp_count; + + /// Core math instruction + MathInstructionDescription math_instruction; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int minimum_compute_capability; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int maximum_compute_capability; + + /// Describes the shape of a cluster (in blocks) + cutlass::gemm::GemmCoord cluster_shape; + + // + // Methods + // + + TileDescription( + cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), + int threadblock_stages = 0, + cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), + MathInstructionDescription math_instruction = MathInstructionDescription(), + int minimum_compute_capability = 0, + int maximum_compute_capability = 0, + cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) + ): + threadblock_shape(threadblock_shape), + threadblock_stages(threadblock_stages), + warp_count(warp_count), + math_instruction(math_instruction), + minimum_compute_capability(minimum_compute_capability), + maximum_compute_capability(maximum_compute_capability), + cluster_shape(cluster_shape) { } + + // Equality operator + inline + bool operator==(TileDescription const& rhs) const{ + return ( + (threadblock_shape == rhs.threadblock_shape) && + (threadblock_stages == rhs.threadblock_stages) && + (warp_count == rhs.warp_count) && + (math_instruction == rhs.math_instruction) && + (minimum_compute_capability == rhs.minimum_compute_capability) && + (maximum_compute_capability == rhs.maximum_compute_capability)); + } + + // Inequality operator + inline + bool operator!=(TileDescription const& rhs) const { + return !(*this == rhs); + } +}; + +/// High-level description of an operation +struct OperationDescription { + + /// Unique identifier describing the operation + char const * name; + + /// Operation provider + Provider provider; + + /// Kind of operation + OperationKind kind; + + /// Describes the tiled structure of a GEMM-like computation + TileDescription tile_description; + + // + // Methods + // + OperationDescription( + char const * name = "unknown", + Provider provider = Provider::kInvalid, + OperationKind kind = OperationKind::kInvalid, + TileDescription const& tile_description = TileDescription() + ): + name(name), provider(provider), kind(kind), tile_description(tile_description) { } +}; + +/// Structure describing the properties of a tensor +struct TensorDescription { + + /// Numeric type of an individual element + NumericTypeID element; + + /// Enumerant identifying the layout function for the tensor + LayoutTypeID layout; + + /// Alignment restriction on pointers, strides, and extents + int alignment; + + /// log2() of the maximum extent of each dimension + int log_extent_range; + + /// log2() of the maximum value each relevant stride may have + int log_stride_range; + + // + // Methods + // + + TensorDescription( + NumericTypeID element = NumericTypeID::kInvalid, + LayoutTypeID layout = LayoutTypeID::kInvalid, + int alignment = 1, + int log_extent_range = 24, + int log_stride_range = 24 + ): + element(element), + layout(layout), + alignment(alignment), + log_extent_range(log_extent_range), + log_stride_range(log_stride_range) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all GEMM computations +struct GemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the sparse meta matrices + TensorDescription E; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + GemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + GemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +struct BlockScaleDescription { + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + /// Describes the underlying kind of scaling: + /// Tensor Core supported (BlockScaled) or manual scaling (Blockwise) + OperationKind kind; +}; + +struct GroupedGemmDescription : public OperationDescription { + GemmDescription gemm; + std::optional block_scales; +}; + +/// Description of all GEMM computations +struct BlockScaledGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + // + // Methods + // + + BlockScaledGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockScaledGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +/// Description of all GEMM computations +struct BlockwiseGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + // + // Methods + // + + BlockwiseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockwiseGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description for structured sparse GEMMs. +struct SparseGemmDescription : public GemmDescription { + + /// Description structure for structured sparse GEMM + SparseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + TensorDescription const& E = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) + {this->E = E;} +}; + +/// Description of all Reduction operations +struct ReductionDescription : public OperationDescription { + + /// Describes the data type of workspace + NumericTypeID element_workspace; + + /// Describes the data type of final output + NumericTypeID element_output; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; +}; + +/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) +struct RankKDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + RankKKind rank_k_kind; + + /// Number of rank update (rank k or rank 2k) + int num_ranks; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand (used only for SYR2K and HER2K) + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the fill mode for matrix C + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + RankKDescription( + RankKKind rank_k_kind = RankKKind::kUniversal, + int num_ranks = 1, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + rank_k_kind(rank_k_kind), + num_ranks(num_ranks), + A(A), + B(B), + C(C), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all TRMM computations +struct TrmmDescription : public OperationDescription { + + /// Indicates the kind of TRMM performed + TrmmKind trmm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the diag type for matrix A + DiagType diag_type; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription D; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + // + // Methods + // + + TrmmDescription( + TrmmKind trmm_kind = TrmmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + DiagType diag_type = DiagType::kInvalid, + TensorDescription const& B = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone + ): + trmm_kind(trmm_kind), + A(A), + side_mode(side_mode), + fill_mode(fill_mode), + diag_type(diag_type), + B(B), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all SYMM/HEMM update computations +struct SymmDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + SymmKind symm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + SymmDescription( + SymmKind symm_kind = SymmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + symm_kind(symm_kind), + A(A), + B(B), + C(C), + side_mode(side_mode), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all Conv2d operations +struct ConvDescription : public OperationDescription { + /// Describes the convolution dimension support (2D or 3D) + int conv_dim; + + /// Describes the kind of convolution + ConvKind conv_kind; + + /// Describes the type of iterator algorithm (analytic or precomputed) + IteratorAlgorithmID iterator_algorithm; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the C operand + TensorDescription C; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + // + // Methods + // + // Returns Activation TensorDescription + TensorDescription activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return A; + case library::ConvKind::kDgrad : return C; + case library::ConvKind::kWgrad : return B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter TensorDescription + TensorDescription filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return B; + case library::ConvKind::kDgrad : return B; + case library::ConvKind::kWgrad : return C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output TensorDescription + TensorDescription output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return C; + case library::ConvKind::kDgrad : return A; + case library::ConvKind::kWgrad : return A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h new file mode 100644 index 0000000000000000000000000000000000000000..027944eb6ac8c6e8f250d83ed33c0899adfbd3e8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h @@ -0,0 +1,365 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief BLAS-like handle used to launch operations on the CUDA device. +*/ + +#pragma once + +#include +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Handle object +class Handle { +private: + + /// Host workspace + static int const kHostWorkspaceSize = (4 << 10); + + /// Provider of operations + Provider provider_; + + /// CUDA device properties + cudaDeviceProp device_; + + /// CUDA stream + cudaStream_t stream_; + + /// Device workspace + void *workspace_; + + /// Size of device workspace in bytes + size_t workspace_size_; + + /// Indicates whether scalars are host or device pointers + ScalarPointerMode scalar_pointer_mode_; + + /// Pointer to the most recently executed operation + Operation const *last_operation_; + + int device_idx_; + +public: + + /// Constructor + Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); + + /// Destructor + ~Handle(); + + /// Move constructor + Handle(Handle && handle); + + /// Move assignment operator + Handle &operator=(Handle && handle); + + // + // Persistent state accessors + // + + /// Returns compute capability of the selected device + int compute_capability() const; + + /// Sets the current CUDA stream + void set_stream(cudaStream_t stream); + + /// Gets the current CUDA stream + cudaStream_t get_stream() const; + + /// Gets the current provider + Provider get_provider() const; + + /// Sets the provider of operations + void set_provider(Provider provider); + + /// Gets the device workspace size + size_t get_workspace_size() const; + + /// Gets a pointer to the device workspace allocation in Global Memory + void *get_workspace() const; + + /// Sets the size of device workspace, invalidating calls to get_device_workspace() + void set_workspace_size(size_t bytes); + + /// Gets the scalar pointer mode + ScalarPointerMode get_scalar_pointer_mode() const; + + /// Sets the scalar pointer mode + void set_scalar_pointer_mode(ScalarPointerMode mode); + + /// Gets the most recently executed operation + Operation const *get_last_operation() const; + + // + // Computations + // + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C + Status gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int64_t ldd /// Leading dimension of D matrix + ); + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C. + // + // Supports batched-strided, batched array or split-K serial or split-K parallel. + // + Status gemm_universal( + + GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + int cluster_m, /// cluster shape M dimension + int cluster_n, /// cluster shape N dimension + int cluster_k, /// cluster shape K dimension + int cluster_m_fallback, /// Fallback cluster shape M dimension + int cluster_n_fallback, /// Fallback cluster shape N dimension + int cluster_k_fallback, /// Fallback cluster shape K dimension + + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C matrix + LayoutTypeID layout_C, /// Layout of D matrix + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + NumericTypeID element_D, /// Data type of D matrix + LayoutTypeID layout_D, /// Layout of D matrix + void * ptr_D, /// Pointer to D matrix + int64_t ldd, /// Leading dimension of D matrix + + int batch_count = 1, /// Batch count or number of split-K slices + + int64_t batch_stride_A = 0, /// Batch stride of A operand + int64_t batch_stride_B = 0, /// Batch stride of B operand + int64_t batch_stride_C = 0, /// Batch stride of C operand + int64_t batch_stride_D = 0 /// Batch stride of D operand + ); + + /// Planar complex GEMM + /// + /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. + /// + Status gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count = 1, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real = 0, + int64_t batch_stride_A_imag = 0, + + int64_t batch_stride_B_real = 0, + int64_t batch_stride_B_imag = 0, + + int64_t batch_stride_C_real = 0, + int64_t batch_stride_C_imag = 0, + + int64_t batch_stride_D_real = 0, + int64_t batch_stride_D_imag = 0 + ); + + /// Planar complex GEMM loading pointers from arrays in global memory + Status gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices + + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices + + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag /// Leading dimension of imaginary part of D matrix + ); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Unique pointer storing the handle +using HandlePtr = std::unique_ptr; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace +Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace +Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h new file mode 100644 index 0000000000000000000000000000000000000000..6764d9a6d81286c8bba0f5184b17819bfae86978 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h @@ -0,0 +1,995 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#ifndef CUTLASS_LIBRARY_LIBRARY_H +#define CUTLASS_LIBRARY_LIBRARY_H + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/types.h" +#include "cutlass/library/descriptions.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/blas3.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mode of Universal GEMM +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for all operations +class Operation { +public: + + virtual ~Operation() { } + + virtual OperationDescription const & description() const = 0; + + virtual Status can_implement( + void const *configuration, + void const *arguments) const = 0; + + virtual uint64_t get_host_workspace_size( + void const *configuration) const = 0; + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const = 0; + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Originally designed for metadata, but should be useful for FP8/6/4 too. + virtual Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspace_ptrs, + int problem_count, + cudaStream_t stream = nullptr) { + return Status::kErrorNotSupported; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + virtual Status initialize_with_arguments(void* arguments_ptr) const { + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic GEMM operations +// +// OperationKind: Gemm +// GemmKind: Gemm +// +struct GemmConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Number of partitions of K dimension + int split_k_slices{0}; +}; + +/// Arguments for GEMM +struct GemmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Batched + +struct GemmBatchedConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Stride between instances of the A matrix in memory + int64_t batch_stride_A{0}; + + /// Stride between instances of the B matrix in memory + int64_t batch_stride_B{0}; + + /// Stride between instances of the C matrix in memory + int64_t batch_stride_C{0}; + + /// Stride between instances of the D matrix in memory + int64_t batch_stride_D{0}; + + /// Number of GEMMs in batch + int batch_count{1}; +}; + +/// Arguments to batched GEMM +using GemmBatchedArguments = GemmArguments; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Array + +struct GemmArrayConfiguration { + + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + int batch_count{1}; +}; + +/// Arguments for GEMM - used by all the GEMM operations +struct GemmArrayArguments { + void const * const *A{nullptr}; + void const * const *B{nullptr}; + void const * const *C{nullptr}; + void * const *D{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex +// +// OperationKind: Gemm +// GemmKind: Universal + +struct GemmUniversalConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int device_count{1}; +}; + +enum class Sm90MixedInputWiderOperand { + A = 0, + B = 1 +}; + +struct GemmUniversalArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{}; + library::RuntimeDatatype runtime_input_datatype_b{}; + int swizzle_size{1}; + int split_k_slices{1}; + + // For SM90 mixed input dtype kernels + bool is_sm90_mixed_dtype{false}; + Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B}; + bool generate_scale_and_zero{false}; + bool generate_dequantized_AB{false}; + void *Scale{nullptr}; // Scale tensor + void *Zero{nullptr}; // Zero tensor + void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + void *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + int device_index{0}; + + bool use_pdl{false}; +}; + +/// Block Scaled GEMM +// +// OperationKind: kBlockScaledGemm +// GemmKind: Universal + +struct BlockScaledGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + void *SFD{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for ScaleFactor Generation + void const *norm_constant{nullptr}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + +/// Blockwise GEMM +// +// OperationKind: kBlockwiseGemm +// GemmKind: Universal + +struct BlockwiseGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + int sf_m_vec_size{0}; + int sf_n_vec_size{0}; + int sf_k_vec_size{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Complex valued GEMM in which real and imaginary parts are separated by a stride +// +// OperationKind: Gemm +// GemmKind: Planar complex + +struct GemmPlanarComplexConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArguments { + + void const *A_real{nullptr}; + void const *A_imag{nullptr}; + void const *B_real{nullptr}; + void const *B_imag{nullptr}; + void const *C_real{nullptr}; + void const *C_imag{nullptr}; + void *D_real{nullptr}; + void *D_imag{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A_real{0}; + int64_t batch_stride_A_imag{0}; + int64_t batch_stride_B_real{0}; + int64_t batch_stride_B_imag{0}; + int64_t batch_stride_C_real{0}; + int64_t batch_stride_C_imag{0}; + int64_t batch_stride_D_real{0}; + int64_t batch_stride_D_imag{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This is a special form of planar complex which loads pointers and problem size +/// from memory. +struct GemmPlanarComplexArrayConfiguration { + + gemm::GemmCoord problem_size{}; + int batch_count{1}; + + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArrayArguments { + + int const *M{nullptr}; + int const *N{nullptr}; + int const *K{nullptr}; + + void const * const * A_real{nullptr}; + void const * const * A_imag{nullptr}; + void const * const * B_real{nullptr}; + void const * const * B_imag{nullptr}; + void const * const * C_real{nullptr}; + void const * const * C_imag{nullptr}; + void * const * D_real{nullptr}; + void * const * D_imag{nullptr}; + + void const * alpha{nullptr}; + void const * beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Grouped GEMM supporting +// +// OperationKind: Gemm +// GemmKind: Grouped + +struct GemmGroupedConfiguration { + int problem_count{0}; + // GemmGroupedConfiguration is passed to initialize(), which + // is responsible for allocating the device-side stride storage. + int64_t* lda; + int64_t* ldb; + int64_t* ldc; + + cute::Shape* problem_sizes_3x_host; +}; + +struct GemmGroupedArguments { + int problem_count{}; + gemm::GemmCoord* problem_sizes{nullptr}; + + void* ptr_A{nullptr}; + void* ptr_B{nullptr}; + void* ptr_C{nullptr}; + void* ptr_D{nullptr}; + + int64_t* lda{nullptr}; + int64_t* ldb{nullptr}; + int64_t* ldc{nullptr}; + int64_t* ldd{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; + + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + int swizzle_size{1}; + + // these should really be in the configuration but staying consistent with GEMM + int sm_count{0}; + int max_active_clusters{0}; + + // The user is responsible for allocating storage for problem sizes. + // Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we + // unfortunately need to have both options in this struct, and the + // underlying operation uses the one it needs. + cute::Shape* problem_sizes_3x; + cute::Shape* problem_sizes_3x_host; +}; + +struct GroupedGemmBlockScaledArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; + void* SFD{nullptr}; + void* norm_constant{nullptr}; +}; + +struct GroupedGemmBlockwiseArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// OperationKind: kSparseGemm +// + +/// Computes GEMM assuming one of the inputs has 2:4 structured sparsity. +struct SparseGemmConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; /// number of sparse matrix products in batch + int64_t lda{0}; /// leading dimension of A operand + int64_t ldb{0}; /// leading dimension of B operand + int64_t ldc{0}; /// leading dimension of C operand + int64_t ldd{0}; /// leading dimension of D operand + int64_t lde{0}; /// leading dimension of E operand (metadata matrix) + int64_t batch_stride_A{0}; // stride between matrices + int64_t batch_stride_B{0}; // stride between matrices + int64_t batch_stride_C{0}; // stride between matrices + int64_t batch_stride_D{0}; // stride between matrices + int64_t batch_stride_E{0}; // stride between matrices +}; + +/// Arguments for sparse GEMMs +struct SparseGemmArguments { + void const *A{nullptr}; /// pointer to A matrix + void const *B{nullptr}; /// pointer to B matrix + void const *C{nullptr}; /// pointer to C matrix + void *D{nullptr}; /// pointer to D matrix + void const *E{nullptr}; /// pointer to E matrix (metadata) + void const *alpha{nullptr}; /// pointer to alpha scalar + void const *beta{nullptr}; /// pointer to beta scalar + ScalarPointerMode pointer_mode{}; /// enumerant indicating whether alpha/beta pointers are host + /// or device pointers. + bool use_pdl{false}; /// Whether to use PDL when launching the kernel +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic Rank K update operations +// +// OperationKind: (Syrk, Herk, Syr2k, Her2k) +// RankKKind: Universal +// +struct RankKConfiguration { + + /// SYRK problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Syrk, Herk, Syr2k, Her2k) +struct RankKArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix (used only for Syr2k and Her2k) + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic TRMM operations +// +// OperationKind: Trmm +// TrmmKind: Universal +// +struct TrmmConfiguration { + + /// TRMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for TRMM +struct TrmmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic SYMM/HEMM update operations +// +// OperationKind: (Symm, Hemm) +// SymmKind: Universal +// +struct SymmConfiguration { + + /// SYMM/HEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Symm, Hemm) +struct SymmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Two dimensional convolution +// +// OperationKind: Conv2d +// +struct Conv2dConfiguration { + + conv::SplitKMode split_k_mode; + + /// Conv2d problem size + // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv2dProblemSize problem_size{}; + + // stride of operand A + std::vector stride_a{}; + + // stride of operand B + std::vector stride_b{}; + + // stride of operand C + std::vector stride_c{}; +}; + + +/// Three dimensional convolution +// +// OperationKind: Conv3d +// +struct Conv3dConfiguration { + + conv::SplitKMode split_k_mode{}; + + /// Conv2d problem size + // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv3dProblemSize problem_size{}; + + /// Layout object for activations tensor + layout::TensorNDHWC layout_activations{}; + + /// Layout object for filters tensor + layout::TensorNDHWC layout_filters{}; + + /// Layout object for source tensor + layout::TensorNDHWC layout_source{}; + + /// Layout object for output tensor + layout::TensorNDHWC layout_output{}; + + // + // Methods + // + + // Mapping functions (A,B,C -> activation,filter,output) + layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_activations; + case library::ConvKind::kDgrad: return layout_output; + case library::ConvKind::kWgrad: return layout_output; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_filters; + case library::ConvKind::kDgrad: return layout_filters; + case library::ConvKind::kWgrad: return layout_activations; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_c(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_output; + case library::ConvKind::kDgrad: return layout_activations; + case library::ConvKind::kWgrad: return layout_filters; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +/// Arguments for CONV +struct ConvArguments { + + ///////////////////////////////////////////////////////// + /// ImplicitGemm matrices A, B, C, D + ///////////////////////////////////////////////////////// + /// pointer to implicit gemm matrix A + void const *A{nullptr}; + + /// pointer to implicit gemm matrix B + void const *B{nullptr}; + + /// pointer to reordered matrix B + void const *reordered_B{nullptr}; + + /// pointer to implicit gemm matrix C + void const *C{nullptr}; + + /// pointer to implicit gemm destination matrix D + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for Reduction operations +// +// OperationKind: Reduction +// +struct ReductionConfiguration { + + /// Reduction problem size + MatrixCoord problem_size{}; + + /// Number of partitions to reduce + int partitions{0}; + + /// Number of elements between each partition + int64_t partition_stride{0}; + + /// leading dimension of 'w'orkspace operand + int64_t ldw{0}; + + /// leading dimension of 's'ource operand + int64_t lds{0}; + + /// leading dimension of 'd'estination operand + int64_t ldd{0}; +}; + +/// Arguments for Reduction +struct ReductionArguments { + + /// Pointer to workspace matrix + void const *workspace{nullptr}; + + /// Pointer to source matrix + void const *source{nullptr}; + + /// Pointer to destination matrix + void *destination{nullptr}; + + /// pointer to reference matrix + void *reference{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h new file mode 100644 index 0000000000000000000000000000000000000000..c4fb0ee8ca32124450b1063cc3613078e600479d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Manifest of CUTLASS Library + + This is the root of the data structure containing CUTLASS objects +*/ + +#pragma once + +#include +#include +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "library.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declaration +class Manifest; + +// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) +void initialize_all(Manifest &manifest); + +// init and insert all reduction op in manifest object (manually instantiated in library/reduction) +void initialize_all_reduction_op(Manifest &manifest); + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +/// List of operations +using OperationVector = std::vector>; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Manifest of CUTLASS Library +class Manifest { +private: + + /// Operation provider + Provider provider_; + + /// Global list of operations + OperationVector operations_; + +public: + Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } + + /// Top-level initialization + Status initialize(); + + /// Used for initialization + void reserve(size_t operation_count); + + /// Graceful shutdown + Status release(); + + /// Appends an operation and takes ownership + void append(Operation *operation_ptr) {\ + // This function is inline s.t. it is present in generated libraries + // without having to compile or link in manifest.cpp + operations_.emplace_back(operation_ptr); + } + + /// Returns an iterator to the first operation + OperationVector const &operations() const; + + /// Returns a const iterator + OperationVector::const_iterator begin() const; + + /// Returns a const iterator + OperationVector::const_iterator end() const; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..f36232c8dc833e2b24d681686f6662e79b7ecd0a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h @@ -0,0 +1,905 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. +*/ + +#pragma once +#include +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct GemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + ComplexTransform transform_A; + NumericTypeID element_B; + LayoutTypeID layout_B; + ComplexTransform transform_B; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + + // + // Methods + // + + inline + GemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + ComplexTransform transform_A = ComplexTransform::kNone, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + ComplexTransform transform_B = ComplexTransform::kNone, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor + ): + provider(provider), + gemm_kind(gemm_kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + transform_A(transform_A), + element_B(element_B), + layout_B(layout_B), + transform_B(transform_B), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D) + { } + + inline + bool operator==(GemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (transform_A == rhs.transform_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (transform_B == rhs.transform_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D); + } + + inline + bool operator!=(GemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " transform_A: " << to_string(k.transform_A) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " transform_B: " << to_string(k.transform_B) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for GemmFunctionalKey +struct GemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(GemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.element_compute)), 3) ^ + rotl(hash(int(key.element_scalar)), 4) ^ + rotl(hash(int(key.element_A)), 5) ^ + rotl(hash(int(key.layout_A)), 6) ^ + rotl(hash(int(key.transform_A)), 7) ^ + rotl(hash(int(key.element_B)), 8) ^ + rotl(hash(int(key.layout_B)), 9) ^ + rotl(hash(int(key.transform_B)), 10) ^ + rotl(hash(int(key.element_C)), 11) ^ + rotl(hash(int(key.layout_C)), 12) ^ + rotl(hash(int(key.element_D)), 13) ^ + rotl(hash(int(key.layout_D)), 14); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for GEMM operators +struct GemmPreferenceKey { + + int compute_capability; + int alignment; + + // + // Methods + // + + GemmPreferenceKey(): compute_capability(), alignment() { } + + GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } + + bool operator<(GemmPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); + } + + bool operator==(GemmPreferenceKey const &rhs) const { + return compute_capability == rhs.compute_capability; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { + out << "{\n" + << "compute_capability : " << key.compute_capability << std::endl + << "alignment : " << key.alignment << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using GemmOperationVectorMap = std::map< + GemmPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using GemmOperationFunctionalMap = std::unordered_map< + GemmFunctionalKey, + GemmOperationVectorMap, + GemmFunctionalKeyHasher +>; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for BlockScaled Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockScaledGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + NumericTypeID element_SFD; + LayoutTypeID layout_SFD; + int SFVecSize; + int EpilogueSFVecSize; + // + // Methods + // + + inline + BlockScaledGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockScaledGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFD = NumericTypeID::kF16, + LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor, + int sf_vec_size = 32 + , int epilogue_sf_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + element_SFD(element_SFD), + layout_SFD(layout_SFD), + SFVecSize(sf_vec_size) + , EpilogueSFVecSize(epilogue_sf_vec_size) + { } + + inline + bool operator==(BlockScaledGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (element_SFD == rhs.element_SFD) && + (layout_SFD == rhs.layout_SFD) && + (SFVecSize == rhs.SFVecSize) + && (EpilogueSFVecSize == rhs.EpilogueSFVecSize) + ; + } + + inline + bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " element_SFD: " << to_string(k.element_SFD) << "\n" + << " layout_SFD: " << to_string(k.layout_SFD) << "\n" + << " SFVecSize: " << k.SFVecSize << "\n" + << "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockScaledGemmFunctionalKeyHasher +struct BlockScaledGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockScaledGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.element_SFD)), 16) ^ + rotl(hash(int(key.layout_SFD)), 17) ^ + rotl(hash(int(key.SFVecSize)), 18) ^ + rotl(hash(int(key.EpilogueSFVecSize)), 19) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockScaledGemmOperationFunctionalMap = std::unordered_map< + BlockScaledGemmFunctionalKey, + GemmOperationVectorMap, + BlockScaledGemmFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Blockwise Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockwiseGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + // + // Methods + // + + inline + BlockwiseGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockwiseGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + int sfm_vec_size = 32, + int sfn_vec_size = 32, + int sfk_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + SFMVecSize(sfm_vec_size), + SFNVecSize(sfn_vec_size), + SFKVecSize(sfk_vec_size) + { } + + inline + bool operator==(BlockwiseGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (SFMVecSize == rhs.SFMVecSize) && + (SFNVecSize == rhs.SFNVecSize) && + (SFKVecSize == rhs.SFKVecSize); + } + + inline + bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockwiseGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " SFMVecSize: " << k.SFMVecSize << "\n" + << " SFNVecSize: " << k.SFNVecSize << "\n" + << " SFKVecSize: " << k.SFKVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockwiseGemmFunctionalKeyHasher +struct BlockwiseGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockwiseGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.SFMVecSize)), 16) ^ + rotl(hash(int(key.SFNVecSize)), 17) ^ + rotl(hash(int(key.SFKVecSize)), 18) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockwiseGemmOperationFunctionalMap = std::unordered_map< + BlockwiseGemmFunctionalKey, + GemmOperationVectorMap, + BlockwiseGemmFunctionalKeyHasher +>; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Conv Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying conv2d functional behavior +struct ConvFunctionalKey { + library::Provider provider; + library::ConvKind conv_kind; + library::NumericTypeID element_A; + library::LayoutTypeID layout_A; + library::NumericTypeID element_B; + library::LayoutTypeID layout_B; + library::NumericTypeID element_C; + library::LayoutTypeID layout_C; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_compute; + + + // + // Methods + // + + inline + ConvFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::ConvKind conv_kind = library::ConvKind::kFprop, + library::NumericTypeID element_A = library::NumericTypeID::kF16, + library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_B = library::NumericTypeID::kF16, + library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_C = library::NumericTypeID::kF16, + library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_compute = library::NumericTypeID::kF32 + ): + provider(provider), + conv_kind(conv_kind), + element_A(element_A), + layout_A(layout_A), + element_B(element_B), + layout_B(layout_B), + element_C(element_C), + layout_C(layout_C), + element_accumulator(element_accumulator), + element_compute(element_compute) + { } + + inline + bool operator==(ConvFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (conv_kind == rhs.conv_kind) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_accumulator == rhs.element_accumulator) && + (element_compute == rhs.element_compute); + } + + inline + bool operator!=(ConvFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { + out << "{\n" + << "provider: " << to_string(key.provider) << std::endl + << "conv_kind: " << to_string(key.conv_kind) << std::endl + << "element_A: " << to_string(key.element_A) << std::endl + << "layout_A: " << to_string(key.layout_A) << std::endl + << "element_B: " << to_string(key.element_B) << std::endl + << "layout_B: " << to_string(key.layout_B) << std::endl + << "element_C: " << to_string(key.element_C) << std::endl + << "layout_C: " << to_string(key.layout_C) << std::endl + << "element_accumulator: " << to_string(key.element_accumulator) << std::endl + << "element_compute: " << to_string(key.element_compute) << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +struct ConvFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ConvFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.conv_kind)), 2) ^ + rotl(hash(int(key.element_A)), 3) ^ + rotl(hash(int(key.layout_A)), 4) ^ + rotl(hash(int(key.element_B)), 5) ^ + rotl(hash(int(key.layout_B)), 6) ^ + rotl(hash(int(key.element_C)), 7) ^ + rotl(hash(int(key.layout_C)), 8) ^ + rotl(hash(int(key.element_accumulator)), 9) ^ + rotl(hash(int(key.element_compute)), 10); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for Conv2d operators +struct ConvPreferenceKey { + + int compute_capability; + IteratorAlgorithmID iterator_algorithm; + + + // + // Methods + // + + ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } + + ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): + compute_capability(cc), iterator_algorithm(iterator_algorithm) { } + + bool operator<(ConvPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); + } + + bool operator==(ConvPreferenceKey const &rhs) const { + return (compute_capability == rhs.compute_capability) && + (iterator_algorithm == rhs.iterator_algorithm); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using ConvOperationVectorMap = std::map< + ConvPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using ConvOperationFunctionalMap = std::unordered_map< + ConvFunctionalKey, + ConvOperationVectorMap, + ConvFunctionalKeyHasher +>; +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Tuple uniquely identifying conv2d functional behavior +struct ReductionFunctionalKey { + library::Provider provider; + library::NumericTypeID element_workspace; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_output; + library::NumericTypeID element_compute; + library::MathOperationID reduce_math_op; + library::EpilogueKind epilogue_math_op; + + + // + // Methods + // + + inline + ReductionFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::NumericTypeID element_workspace = library::NumericTypeID::kF16, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_output = library::NumericTypeID::kF16, + library::NumericTypeID element_compute = library::NumericTypeID::kF32, + library::MathOperationID reduce_math_op = library::MathOperationID::kAdd, + library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination + ): + provider(provider), + element_workspace(element_workspace), + element_accumulator(element_accumulator), + element_output(element_output), + element_compute(element_compute), + reduce_math_op(reduce_math_op), + epilogue_math_op(epilogue_math_op) + { } + + inline + bool operator==(ReductionFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (element_workspace == rhs.element_workspace) && + (element_accumulator == rhs.element_accumulator) && + (element_output == rhs.element_output) && + (element_compute == rhs.element_compute) && + (reduce_math_op == rhs.reduce_math_op) && + (epilogue_math_op == rhs.epilogue_math_op); + } + + inline + bool operator!=(ReductionFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +struct ReductionFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ReductionFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.element_workspace)), 2) ^ + rotl(hash(int(key.element_accumulator)), 3) ^ + rotl(hash(int(key.element_output)), 4) ^ + rotl(hash(int(key.element_compute)), 5) ^ + rotl(hash(int(key.reduce_math_op)), 6) ^ + rotl(hash(int(key.epilogue_math_op)), 7); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { + out << "{\n" + << "provider: " << library::to_string(key.provider) << std::endl + << "element_workspace : " << library::to_string(key.element_workspace) << std::endl + << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl + << "element_output : " << library::to_string(key.element_output) << std::endl + << "element_compute : " << library::to_string(key.element_compute) << std::endl + << "}"; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key +// i.e. only one tile size configuration per functional key +using ReductionOperationFunctionalMap = std::unordered_map< + ReductionFunctionalKey, + library::Operation const *, + ReductionFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Table of cutlass::library::Operation instances +class OperationTable { +public: + + /// Map of all operations of type kGemm + // provider (kCUTLASS) + GemmOperationFunctionalMap gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv2d_operations; + + /// Map of all operations of type kConv3d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv3d_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS) + ReductionOperationFunctionalMap reduction_operations; + +public: + + void append(Manifest const &manifest); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h new file mode 100644 index 0000000000000000000000000000000000000000..9a757433f38fbf10d9a352e07c7f3084a99e4098 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h @@ -0,0 +1,68 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Singleton instance stores a Manifest and Operation table +class Singleton { +public: + + /// Manifest object + Manifest manifest; + + /// Operation table referencing the Manifest + OperationTable operation_table; + +public: + + Singleton(); + + static Singleton const &get(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8c4ff13ba543b4ec63997ba55e9278bfb357a6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h @@ -0,0 +1,295 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + #pragma once + + ///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Layout type identifier +enum class LayoutTypeID { + kUnknown, + kColumnMajor, + kRowMajor, + kBlockScalingTensor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, + kColumnMajorInterleavedK4, + kRowMajorInterleavedK4, + kColumnMajorInterleavedK16, + kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, + kTensorNCHW, + kTensorNCDHW, + kTensorNHWC, + kTensorNDHWC, + kTensorNC32HW32, + kTensorC32RSK32, + kTensorNC64HW64, + kTensorC64RSK64, + kInvalid +}; + +/// Numeric data type +enum class NumericTypeID { + kUnknown, + kVoid, + kB1, + kU2, + kU4, + kU8, + kU16, + kU32, + kU64, + kS2, + kS4, + kS8, + kS16, + kS32, + kS64, + kFE4M3, + kFE5M2, + + kFE2M3, + kFE3M2, + kFE2M1, + kFUE8M0, + kFUE4M3, + kF8, + kF6, + kF4, + + kF16, + kBF16, + kTF32, + kF32, + kF64, + kCF16, + kCBF16, + kCF32, + kCTF32, + kCF64, + kCS2, + kCS4, + kCS8, + kCS16, + kCS32, + kCS64, + kCU2, + kCU4, + kCU8, + kCU16, + kCU32, + kCU64, + kInvalid +}; + +/// Enumerated type describing a transformation on a complex value. +enum class ComplexTransform { + kNone, + kConjugate, + kInvalid +}; + +/// Providers +enum class Provider { + kNone, + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kCUDNN, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating the kind of operation +enum class OperationKind { + kGemm, + kBlockScaledGemm, + kBlockwiseGemm, + kRankK, + kRank2K, + kTrmm, + kSymm, + kConv2d, + kConv3d, + kEqGemm, + kSparseGemm, + kReduction, + kGroupedGemm, + kInvalid +}; + +/// Enumeration indicating whether scalars are in host or device memory +enum class ScalarPointerMode { + kHost, + kDevice, + kInvalid +}; + +/// Describes how reductions are performed across threadblocks +enum class SplitKMode { + kNone, + kSerial, + kParallel, + kParallelSerial, + kInvalid +}; + +/// Indicates the classificaition of the math instruction +enum class OpcodeClassID { + kSimt, + kTensorOp, + kWmmaTensorOp, + kSparseTensorOp, + kBlockScaledOp, + kInvalid +}; + +enum class MathOperationID { + kAdd, + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddMixedInputUpcast, + kMultiplyAddFastBF16, + kMultiplyAddFastF16, + kMultiplyAddFastF32, + kMultiplyAddComplex, + kMultiplyAddComplexFastF32, + kMultiplyAddGaussianComplex, + kXorPopc, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating what kind of GEMM operation to perform +enum class GemmKind { + kGemm, + kBlockScaledGemm, + kSparse, + kUniversal, + kPlanarComplex, + kPlanarComplexArray, + kGrouped, + kInvalid +}; + +/// Enumeration indicating what kind of RankK update operation to perform +enum class RankKKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of TRMM operation to perform +enum class TrmmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of SYMM/HEMM operation to perform +enum class SymmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of Conv2d operation to perform +enum class ConvKind { + kUnknown, + kFprop, + kDgrad, + kWgrad, + kInvalid +}; + +enum class ConvModeID { + kCrossCorrelation, + kConvolution, + kInvalid +}; + +// Iterator algorithm enum in order of general performance-efficiency +enum class IteratorAlgorithmID { + kNone, + kAnalytic, + kOptimized, + kFixedChannels, + kFewChannels, + kInvalid +}; + + +enum class EpilogueKind { + kUnknown, + kConversion, + kLinearCombination, + kLinearCombinationClamp, + kLinearCombinationPlanarComplex, + kLinearCombinationRelu, + kLinearCombinationSigmoid, + kInvalid +}; + + +enum class RuntimeDatatype { + kStatic, + kE4M3, + kE5M2, + kE3M2, + kE2M3, + kE2M1, + + kInvalid +}; + + +enum class RasterOrder { + kAlongN, + kAlongM, + kHeuristic, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h new file mode 100644 index 0000000000000000000000000000000000000000..f537421751c1f2af3b95a2e1951006af441b28e0 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h @@ -0,0 +1,281 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief Utilities accompanying the CUTLASS library for interacting with Library types. +*/ + +#ifndef CUTLASS_LIBRARY_UTIL_H +#define CUTLASS_LIBRARY_UTIL_H + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast from string +template T from_string(std::string const &); + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty = false); + +/// Parses a Provider enumerant from a string +template <> Provider from_string(std::string const &str); + +/// Converts a GemmKind enumerant to a string +char const *to_string(GemmKind type, bool pretty = false); + +/// Converts a RankKKind enumerant to a string +char const *to_string(RankKKind type, bool pretty = false); + +/// Converts a TrmmKind enumerant to a string +char const *to_string(TrmmKind type, bool pretty = false); + +/// Converts a SymmKind enumerant to a string +char const *to_string(SymmKind type, bool pretty = false); + +/// Converts a SideMode enumerant to a string +char const *to_string(SideMode type, bool pretty = false); + +/// Converts a FillMode enumerant to a string +char const *to_string(FillMode type, bool pretty = false); + +/// Converts a BlasMode enumerant to a string +char const *to_string(BlasMode type, bool pretty = false); + +/// Converts a DiagType enumerant to a string +char const *to_string(DiagType type, bool pretty = false); + +/// Converts a NumericType enumerant to a string +char const *to_string(OperationKind type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> OperationKind from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(NumericTypeID type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> NumericTypeID from_string(std::string const &str); + +/// Returns the size of a data type in bits +int sizeof_bits(NumericTypeID type); + +/// Returns true if the numeric type is a complex data type or false if real-valued. +bool is_complex_type(NumericTypeID type); + +/// Returns the real-valued type underlying a type (only different from 'type' if complex) +NumericTypeID get_real_type(NumericTypeID type); + +/// Returns true if numeric type is integer +bool is_integer_type(NumericTypeID type); + +/// Returns true if numeric type is signed +bool is_signed_type(NumericTypeID type); + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type); + +/// returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type); + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type); + +/// To string method for cutlass::Status +char const *to_string(Status status, bool pretty = false); + +/// Converts a LayoutTypeID enumerant to a string +char const *to_string(LayoutTypeID layout, bool pretty = false); + +/// Parses a LayoutType enumerant from a string +template <> LayoutTypeID from_string(std::string const &str); + +/// Returns the rank of a layout's stride base on the LayoutTypeID +int get_layout_stride_rank(LayoutTypeID layout_id); + +/// Converts a OpcodeClassID enumerant to a string +char const *to_string(OpcodeClassID type, bool pretty = false); + +/// Converts a OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string(std::string const &str); + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty = false); + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string(std::string const &str); + + +/// Converts a SplitKMode enumerant to a string +char const *to_string(SplitKMode split_k_mode, bool pretty = false); + +/// Converts a SplitKMode enumerant from a string +template <> +SplitKMode from_string(std::string const &str); + +/// Converts a ConvModeID enumerant to a string +char const *to_string(ConvModeID type, bool pretty = false); + +/// Converts a ConvModeID enumerant from a string +template <> +ConvModeID from_string(std::string const &str); + +/// Converts a IteratorAlgorithmID enumerant to a string +char const *to_string(IteratorAlgorithmID type, bool pretty = false); + +/// Converts a IteratorAlgorithmID enumerant from a string +template <> +IteratorAlgorithmID from_string(std::string const &str); + +/// Converts a ConvKind enumerant to a string +char const *to_string(ConvKind type, bool pretty = false); + +/// Converts a ConvKind enumerant from a string +template <> +ConvKind from_string(std::string const &str); + + +/// Converts a RuntimeDatatype enumerant to a string +char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false); + +/// Convers a RuntimeDatatype enumerant from a string +template<> +cutlass::library::RuntimeDatatype from_string(std::string const &str); + + +/// Converts a RasterOrder enumerant to a string +char const *to_string(RasterOrder type, bool pretty = false); + +/// Convers a RasterOrder enumerant from a string +template<> +RasterOrder from_string(std::string const &str); + +/// Converts a bool to a string +char const *to_string(bool type, bool pretty = false); + +/// Convers a bool from a string +template<> +bool from_string(std::string const &str); + +/// Lexical cast from int64_t to string +std::string lexical_cast(int64_t int_value); + +/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid. +bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); + +/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. +std::string lexical_cast(std::vector &bytes, NumericTypeID type); + +/// Casts from a signed int64 to the destination type. Returns true if successful. +bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); + +/// Casts from an unsigned int64 to the destination type. Returns true if successful. +bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); + +/// Casts from a real value represented as a double to the destination type. Returns true if successful. +bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); + +NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << std::endl; \ + return Status::kInvalid; \ + } \ + } while (0) + +// RAII CUDA buffer container +class CudaBuffer { +public: + CudaBuffer() : size_(0), d_ptr_(nullptr) {} + + explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { + cudaError_t err = cudaMalloc(&d_ptr_, size_); + if (err != cudaSuccess) { + throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); + } + } + + ~CudaBuffer() { + if (d_ptr_) { + cudaFree(d_ptr_); + } + } + + CudaBuffer(CudaBuffer const&) = delete; + CudaBuffer& operator=(CudaBuffer const&) = delete; + + CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) { + other.d_ptr_ = nullptr; + other.size_ = 0; + } + + CudaBuffer& operator=(CudaBuffer&& other) noexcept { + if (this != &other) { + if (d_ptr_) { + cudaFree(d_ptr_); + } + d_ptr_ = other.d_ptr_; + size_ = other.size_; + other.d_ptr_ = nullptr; + other.size_ = 0; + } + return *this; + } + + void* data() const noexcept { return d_ptr_; } + size_t size() const noexcept { return size_; } + +private: + size_t size_; + void* d_ptr_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c96b9a2212b42c191551ea70da3ac3baecbed487 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using Sm1xxBlkScaledConfig = typename CollectiveMainloop::Sm1xxBlkScaledConfig; + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; + + +private: + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockScaledGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = LayoutTypeID::kRowMajor; + description_.SFA.alignment = 128; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = LayoutTypeID::kRowMajor; + description_.SFB.alignment = 128; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFVecSize = SFVecSize; + + description_.SFD = make_TensorDescription(128); + description_.EpilogueSFVecSize = SFD_VectorSize; + + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockScaledGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + + operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b); + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockScaledGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00347a993e29035e58401e69698267045b399f7d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::ElementAccumulator; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::ElementAccumulator; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + +private: + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockwiseGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ? + LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + description_.SFA.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + description_.SFB.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockwiseGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockwiseGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockwiseGemmArguments const *arguments = + static_cast(arguments_ptr); + + if (arguments->sf_m_vec_size != description_.SFMVecSize && arguments->sf_m_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_n_vec_size != description_.SFNVecSize && arguments->sf_n_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_k_vec_size != description_.SFKVecSize && arguments->sf_k_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1a1584db92c4379e04c84a2658f79313b3eaad --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h @@ -0,0 +1,650 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv2dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv2dOperationBase(char const *name = "unknown_conv2d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv2d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + // TODO: Add split k mode Serial and parallel to convolutions + // description_.split_k_mode = Operator::kSplitK ? SplitKMode::kSerial : SplitKMode::kNone; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv2dOperation : public Conv2dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DirectConv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConv2dOperation : public Conv2dOperation { +public: + + using Operator = Operator_; + using Base = Conv2dOperation; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_reordered_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fe402c4494c27a882bf42f867a708e954ee87dc0 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h @@ -0,0 +1,389 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/default_conv3d_dgrad.h" +#include "cutlass/conv/kernel/default_conv3d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv3dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv3dOperationBase(char const *name = "unknown_conv3d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv3d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv3dOperation : public Conv3dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv3dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv3dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv3dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv3dOperation::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << ", " + << operator_args.ref_A.stride(3) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << ", " + << operator_args.ref_B.stride(3) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << ", " + << operator_args.ref_C.stride(3) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << ", " + << operator_args.ref_D.stride(3) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86c1513e9c934c22e281cf37e1c5e7783e23d305 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp @@ -0,0 +1,980 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include +#include +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) +#include +#endif + +namespace cutlass::library { + +namespace detail { + +template +constexpr cute::array +vector_to_array_strides_helper(const std::vector& v, + std::index_sequence) +{ + return {v[(sizeof...(Indices) - 1u) - Indices]..., ValueType(1)}; +} + +template +cute::array +vector_to_array_strides(const std::vector& v, std::integral_constant) +{ + static_assert(Size != 0); + CUTLASS_ASSERT(v.size() + 1u == Size); + return vector_to_array_strides_helper(v, std::make_index_sequence{}); +} + +template +constexpr cute::array +coord_to_array_strides_helper( + const ::cutlass::Coord coord, + std::index_sequence) +{ + return {int64_t(coord[(sizeof...(Indices) - 1u) - Indices])..., int64_t(1)}; +} + +template +cute::array +coord_to_array_strides(const ::cutlass::Coord& coord) +{ + static_assert(Rank >= 0); + return coord_to_array_strides_helper(coord, std::make_index_sequence{}); +} + +} // namespace detail + +// Tells the profiler about CUTLASS 3's 2-D and 3-D convolutions. +// For CUTLASS 2's 2-D convolutions, see Conv2dOperation. +// For CUTLASS 2's 3-D convolutions, see Conv3dOperation. +template +class ConvOperation3x : public Operation { +public: + using Operator = Operator_; + + static_assert(Operator::NumSpatialDimensions == 2 || + Operator::NumSpatialDimensions == 3, + "The profiler currently only supports convolutions with 2 or 3 spatial dimensions."); + using LayoutA = cute::conditional_t + >; + using LayoutB = LayoutA; + using LayoutC = LayoutA; + + using ElementA = typename Operator::ElementA; + using ElementB = typename Operator::ElementB; + using ElementC = typename Operator::ElementC; + using ElementD = typename Operator::ElementD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + ConvOperation3x(const char* name = "unknown_cutlass_3_conv") { + // Initialize OperationDescription (the base class) + description_.name = name; + description_.provider = Provider::kCUTLASS; + + if constexpr (Operator::NumSpatialDimensions == 2) { + description_.kind = OperationKind::kConv2d; + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + description_.kind = OperationKind::kConv3d; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationID::kMultiplyAdd; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + // Initialize ConvDescription (the subclass) + + // kConvDim does not exist in Operator for CUTLASS 3 convolutions. + // For CUTLASS 2 convolutions, it is the number of spatial dimensions. + description_.conv_dim = Operator::NumSpatialDimensions; + description_.conv_kind = ConvKindMap::kId; + + description_.iterator_algorithm = {}; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + } + + ~ConvOperation3x() override = default; + + OperationDescription const& description() const override { + return static_cast(description_); + } + +private: + Status update_operator_arguments_from_configuration_2d_or_3d( + typename Operator::Arguments& out_args, + void const* configuration) const { + Status status = Status::kInvalid; + + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + // tools/library/include/cutlass/library/library.h + // defines Conv2dConfiguration. + // tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h + // uses Conv2dConfiguration. + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + return status; + } + +public: + Status can_implement( + void const* configuration, + void const* arguments) const override { + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp accesses "configuration" as + // GemmUniversalConfiguration (which lives in + // tools/library/include/cutlass/library/library.h) and + // "arguments" as GemmUniversalArguments (which lives in + // tools/library/include/cutlass/library/library.h). + // Those things don't apply to convolutions. + // Despite the existence of ConvUniversal, there's no + // corresponding "ConvUniversalConfiguration" or + // "ConvUniversalArguments." + + CUTLASS_ASSERT(configuration != nullptr); + CUTLASS_ASSERT(arguments != nullptr); + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); + return status; + } + + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); + return status; + } + + return Operator::can_implement(out_args); + } + + uint64_t get_host_workspace_size(void const* /* configuration */) const override { + return sizeof(Operator); + } + + uint64_t get_device_workspace_size( + void const* configuration, + void const* arguments = nullptr) const override + { + // This presumes that at least one of configuration or arguments is nonnull. + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp has get_device_workspace_size return 0 on + // error. It's not clear that this is what we want -- perhaps we + // should return something like expected? -- but + // it's the only option that preserves the current interface. + constexpr uint64_t error_indication = 0; + + typename Operator::Arguments out_args{}; + if (configuration != nullptr) { + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + return error_indication; + } + } + if (arguments != nullptr) { + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return error_indication; + } + } + + if (status == Status::kSuccess) { + return static_cast(Operator::get_workspace_size(out_args)); + } + else { + return error_indication; + } + } + + Status initialize( + void const* configuration, + void* host_workspace, + void* /* device_workspace */ = nullptr, + cudaStream_t stream = nullptr) const override + { + Status status = Status::kInvalid; + + if (configuration == nullptr) { + CUTLASS_TRACE_HOST("Input configuration is null."); + return Status::kInvalid; + } + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + // Any kind of failure invalidates the last successful configuration. + clear_last_successful_config(); + return status; + } + else { + set_last_successful_config(configuration); + } + + if (host_workspace == nullptr) { + CUTLASS_TRACE_HOST("host_workspace is null."); + return Status::kInvalid; + } + (void) new (host_workspace) Operator; + return status; + + // CUTLASS 2 convolutions call the Operator's initialize function + // here, like this. + // + //return op->initialize(args, device_workspace, stream); + // + // CUTLASS 3 convolutions (ConvUniversal), like CUTLASS 3 Gemms + // (GemmUniversal), lack an "initialize" member function. + } + + Status run( + void const* arguments, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override + { + auto status = Status::kInvalid; + + // The Operator doesn't appear to save the last configuration (it + // doesn't have a way to do that, since it lacks an initialize() + // member function), so we have to use the stored configuration + // from the last successful initialize() call (if any). + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_stored_configuration(out_args); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("Updating from previous successful configuration failed."); + return status; + } + + if (arguments == nullptr) { + CUTLASS_TRACE_HOST("Input argument 'arguments' is null."); + return Status::kInvalid; + } + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return status; + } + + auto* op = reinterpret_cast(host_workspace); + return op->run(out_args, device_workspace, stream, nullptr, in_args_ptr->use_pdl); + } + +private: + ConvDescription description_; + // Result of initialize() calling + // update_operator_arguments_from_configuration() successfully. + // This is needed because run() doesn't take a configuration, just + // arguments, and the kernel doesn't appear to save the + // configuration from the last initialize() call. + // + // Unfortunately, this must be declared mutable, because it must be + // set in initialize(), and initialize() is inherited as const. + mutable std::variant< + std::monostate, + Conv2dConfiguration, + Conv3dConfiguration> last_successful_config_{std::monostate{}}; + + // Clear the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void clear_last_successful_config() const { + last_successful_config_ = std::monostate{}; + } + + // Set the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void set_last_successful_config(void const* configuration) const { + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + } + + // Whether a configuration from a successful initialize() call exists. + bool last_successful_config_exists() const { + return not std::holds_alternative(last_successful_config_); + } + + // Visitor for update_operator_arguments_from_stored_configuration. + struct ConfigurationVisitor { + typename Operator::Arguments& out_args; + + Status operator() (std::monostate const&) const { + CUTLASS_TRACE_HOST("No successful previous configuration exists. " + "One cause is calling run() before a successful initialize() call."); + return Status::kInvalid; + } + Status operator() (Conv2dConfiguration const& conf2d) const { + return update_operator_arguments_from_configuration(out_args, conf2d); + } + Status operator() (Conv3dConfiguration const& conf3d) const { + return update_operator_arguments_from_configuration(out_args, conf3d); + } + }; + + // Like update_operator_arguments_from_configuration, but on the + // stored configuration from the last successful initialize() call, + // if any. If there was no last successful initialize() call, + // then return Status::kInvalid. + // + // Unfortunately, this must be declared const, because run() is. + Status update_operator_arguments_from_stored_configuration( + typename Operator::Arguments& out_args) const + { + return std::visit(ConfigurationVisitor{out_args}, last_successful_config_); + } + + template + struct UpdateFusionArgs { + static Status update_( + FusionArgs const&, + ConvArguments const&) + { + // For custom EVT, it is the user's responsibility to ensure + // that alpha and beta are updated appropriately. + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_( + FusionArgs& fusion_args, + ConvArguments const& arguments) + { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv2dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv2dConfiguration)\n"); +#endif + using detail::vector_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 2) { + CUTLASS_TRACE_HOST("You can only use Conv2dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 2."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv2dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. + + const int N = config.problem_size.N; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. + const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + constexpr auto the_stride_size = std::integral_constant{}; + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (auto const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < stride.size(); ++k) { + std::cerr << stride[k]; + if (k + 1u < stride.size()) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(config.stride_a, "config.stride_a"); + print_stride(config.stride_b, "config.stride_b"); + print_stride(config.stride_c, "config.stride_c"); +#endif + + // Conv2dConfiguration stores the strides as std::vector, + // so the code needs to check the run-time vector lengths. + if (config.stride_a.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_a.size() + 1u = " + << (config.stride_a.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_b.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_b.size() + 1u = " + << (config.stride_b.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_c.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_c.size() + 1u = " + << (config.stride_c.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); + const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); + const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); + + // cutlass::library::Conv2dConfiguration has no member stride_d. + // The code below imitates the testbed, + // which just sets D's strides to C's strides. + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // This means that stride_act isn't always config.stride_A, + // depending on Fprop / Dgrad / Wgrad. The code here "undoes" + // the logic in Conv2dWorkspace::set_stride_vector so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_h, pad_w}, + /* upper_padding = */ {pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). + { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv3dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv3dConfiguration)\n"); +#endif + using detail::coord_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 3) { + CUTLASS_TRACE_HOST("You can only use Conv3dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 3."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv3dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. + + const int N = config.problem_size.N; + const int D = config.problem_size.D; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int T = config.problem_size.T; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_d = config.problem_size.pad_d; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_d = config.problem_size.stride_d; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_d = config.problem_size.dilation_d; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. + const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + using Stride = cutlass::layout::TensorNDHWC::Stride; + static_assert(std::is_same_v>); + + const cutlass::library::ConvKind conv_kind = [] () { + constexpr cutlass::conv::Operator op = Operator::DispatchPolicy::ConvOp; + if constexpr (op == cutlass::conv::Operator::kFprop) { + return library::ConvKind::kFprop; + } + else if constexpr (op == cutlass::conv::Operator::kDgrad) { + return library::ConvKind::kDgrad; + } + else /* if constexpr (op == cutlass::conv::Operator::kWgrad) */ { + return library::ConvKind::kWgrad; + } + } (); + const Stride input_stride_a = config.layout_a(conv_kind).stride(); + const Stride input_stride_b = config.layout_b(conv_kind).stride(); + const Stride input_stride_c = config.layout_c(conv_kind).stride(); + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (Stride const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < Stride::kRank; ++k) { + std::cerr << stride[static_cast(k)]; + if (k + 1u < Stride::kRank) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(input_stride_a, "input_stride_a"); + print_stride(input_stride_b, "input_stride_b"); + print_stride(input_stride_c, "input_stride_c"); +#endif + // Conv3dConfiguration stores the strides as Coord (with + // compile-time size), so there's no need to check sizes here + // (unlike Conv2dConfiguration, which stores strides as + // std::vector). + + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = coord_to_array_strides(input_stride_a); + const TensorStride stride_B = coord_to_array_strides(input_stride_b); + const TensorStride stride_C = coord_to_array_strides(input_stride_c); + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // Conv3dConfiguration differs a bit from Conv2dConfiguration, + // but the idea is the same: the "input_stride_a" from config + // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act + // isn't always input_stride_a. Analogously, stride_flt isn't + // always input_stride_b. The code here "undoes" the logic in + // config.layout_a(conv_kind) and config.layout_b(conv_kind) + // (analogous to Conv2dWorkspace::set_stride_vector) so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, D, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, T, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_d, pad_h, pad_w}, + /* upper_padding = */ {pad_d, pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_d, dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). + { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + Status update_operator_arguments_from_arguments( + typename Operator::Arguments& out_args, + ConvArguments const& in_args) const + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); +#endif + auto status = UpdateFusionArgs::update_( + out_args.epilogue.thread, in_args); + if (status != Status::kSuccess) { + return status; + } + + out_args.mainloop.ptr_A = reinterpret_cast(in_args.A); + out_args.mainloop.ptr_B = reinterpret_cast(in_args.B); + + out_args.epilogue.ptr_C = reinterpret_cast(in_args.C); + out_args.epilogue.ptr_D = reinterpret_cast(in_args.D); + + return Status::kSuccess; + } +}; + +} // namespace cutlass::library diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..880cb4bf34b1f3d946e1dc86b80806309bb2b3c1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h @@ -0,0 +1,1408 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/gemm/device/gemm_batched.h" +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmOperationBase(char const *name = "unknown_gemm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kGemm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::GemmKernel::WarpCount::kM, + Operator::GemmKernel::WarpCount::kN, + Operator::GemmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kGemm; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + + operator_args.split_k_slices = configuration->split_k_slices; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmSparseOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementE = typename Operator::ElementE; + using LayoutE = typename Operator::LayoutE; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.kind = OperationKind::kSparseGemm; + this->description_.gemm_kind = GemmKind::kSparse; + this->description_.E = make_TensorDescription(Operator::kAlignmentE); + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SparseGemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + operator_args.ref_E = {nullptr, configuration->lde}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SparseGemmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + operator_args.ref_E.reset(static_cast(arguments->E)); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SparseGemmConfiguration const *configuration = + static_cast(configuration_ptr); + + SparseGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmUniversalOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmUniversalConfiguration const *configuration) { + + operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = (configuration->lda); + operator_args.ldb = (configuration->ldb); + operator_args.ldc = (configuration->ldc); + operator_args.ldd = (configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplex; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.batch_stride_A = arguments->batch_stride_A_real; + operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; + operator_args.batch_stride_B = arguments->batch_stride_B_real; + operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; + operator_args.batch_stride_C = arguments->batch_stride_C_real; + operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; + operator_args.batch_stride_D = arguments->batch_stride_D_real; + operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplexArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.ptr_M = arguments->M; + operator_args.ptr_N = arguments->N; + operator_args.ptr_K = arguments->K; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmGroupedOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmGroupedOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.provider = Provider::kCUTLASS; + this->threadblock_count = Operator::sufficient(); + + this->description_.gemm = GemmOperationBase::description_; + this->description_.gemm.gemm_kind = GemmKind::kGrouped; + this->description_.tile_description = this->description_.gemm.tile_description; + } + + /// Returns the description of the GroupedGEMM operation + virtual OperationDescription const & description() const override final { + return description_; + } + + +private: + int threadblock_count; + GroupedGemmDescription description_; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + Status construct_arguments_( + OperatorArguments &op_args, + GemmGroupedConfiguration const *config) const { + + op_args.problem_count = config->problem_count; + op_args.threadblock_count = threadblock_count; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments &op_args, + GemmGroupedArguments const *arguments) const { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { + + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + op_args.threadblock_count = threadblock_count; + op_args.problem_count = arguments->problem_count; + op_args.problem_sizes = arguments->problem_sizes; + + op_args.ptr_A = static_cast(arguments->ptr_A); + op_args.ptr_B = static_cast(arguments->ptr_B); + op_args.ptr_C = static_cast(arguments->ptr_C); + op_args.ptr_D = static_cast(arguments->ptr_D); + + op_args.lda = arguments->lda; + op_args.ldb = arguments->ldb; + op_args.ldc = arguments->ldc; + op_args.ldd = arguments->ldd; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmGroupedConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmGroupedArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1d17943f11fe8126b3070c3fcead5598e2d207 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp @@ -0,0 +1,714 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation3xBase : public Operation { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + +protected: + GemmDescription description_; + +public: + + /// Constructor + GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = gemm_kind_; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + GemmDescription const& get_gemm_description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + +public: + + /// Constructor + GemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } + } + +private: + int max_active_clusters{}; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + template class Policy, int Stages, class ClusterShape, class KernelSchedule> + static constexpr bool is_sm90_mixed_dtype_mainloop_(Policy policy) { + return (cute::is_same_v, + cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>); + } + + template + static constexpr bool is_sm90_mixed_dtype_mainloop_(DispatchPolicy) { + return false; + } + + template < + typename ElementWide, + typename ElementNarrow, + typename ElementScaleMainloop, + class ActualStrideAB, + Sm90MixedInputWiderOperand wider_operand, + bool is_n4w8, + typename ElementScale, + typename ElementZero, + class Layout_SZ> + static void dequantize_encode_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + cudaStream_t stream, + const int &problem_mn, + const int &problem_k, + const int &options_l, + const int &options_g, + ElementScale *ptr_S, + ElementZero *ptr_Z, + const size_t &SZ_size, + Layout_SZ layout_SZ + ) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + auto *ptr_dequantized_AB = static_cast(arguments->dequantized_AB); + const ElementNarrow *ptr_AB = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + ptr_AB = static_cast(arguments->B); + } + else { + ptr_AB = static_cast(arguments->A); + } + dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream); + if constexpr(is_n4w8) { + size_t AB_size = cute::size(layout_AB); + cutlass::int4b_t *encoded_AB = static_cast(arguments->encoded_AB); + unified_encode_int4b(ptr_AB, encoded_AB, AB_size); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = static_cast(encoded_AB); + } + else { + operator_args.mainloop.ptr_A = static_cast(encoded_AB); + } + ElementScaleMainloop *ptr_packed_Scale = static_cast(arguments->packed_Scale); + pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size); + } + } + + template < + typename ElementAB, + class ActualStrideAB, + class LayoutAB_Reordered, + class LayoutAtomQuant, + Sm90MixedInputWiderOperand wider_operand> + static void handle_shuffle_tensor_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + const int &problem_mn, + const int &problem_k, + const int &options_l) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.dB = layout_AB_reordered; + } + else { + operator_args.mainloop.dA = layout_AB_reordered; + } + if (arguments->generate_dequantized_AB) { + size_t AB_size = cute::size(layout_AB); + ElementAB *AB_reordered = cutlass::device_memory::allocate(AB_size); + const ElementAB *AB_src = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + AB_src = static_cast(operator_args.mainloop.ptr_B); + } + else { + AB_src = static_cast(operator_args.mainloop.ptr_A); + } + reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered); + ElementAB *AB_dst = static_cast(arguments->encoded_AB); + cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size); + cutlass::device_memory::free(AB_reordered); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = AB_dst; + } + else { + operator_args.mainloop.ptr_A = AB_dst; + } + } + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments& operator_args, + GemmUniversalArguments const* arguments, + cudaStream_t stream = nullptr) const { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + // TODO: type erase Arguments structure in 3.0 GEMM + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + // Stride{A,B} is a Layout if and only if: + // (1) This is a mixed dtype kernel, and + // (2) This mixed dtype kernel is using shuffling, and + // (3) sizeof(narrow_type) == 4 or 8 bits, and + // (4) sizeof(wide_type) == 16 bits. + // If A/B has the narrow data type, Stride{A/B} will be a Layout + constexpr bool is_StrideA_Layout = cute::is_layout::value; + constexpr bool is_StrideB_Layout = cute::is_layout::value; + static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout"); + if constexpr(!is_StrideA_Layout) { + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + } + if constexpr(!is_StrideB_Layout) { + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + } + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy; + if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{})) { + const int problem_m = arguments->problem_size.m(); + const int problem_n = arguments->problem_size.n(); + const int problem_k = arguments->problem_size.k(); + const int options_l = arguments->batch_count; + + constexpr Sm90MixedInputWiderOperand wider_operand = + (cutlass::sizeof_bits::value > cutlass::sizeof_bits::value) ? + Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B; + using ElementWide = std::conditional_t; + using ElementNarrow = std::conditional_t; + + constexpr bool has_scale = !std::is_same_v; + constexpr bool has_zero = !std::is_same_v; + + const int options_g = problem_k; + const int scale_k = (problem_k + options_g - 1) / options_g; + + constexpr bool is_A4B8 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_A8B4 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4; + + // If this is a convert-only kernel, we still need to generate dequantized A or B for verification, + // and in this case ElementScale is the same as ElementWide + // In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element + using DummyElementScaleMainloop = std::conditional_t< + is_int4_x_fp8, + typename cutlass::Array, + ElementWide + >; + using ElementScaleMainloop = std::conditional_t< + has_scale, + typename CollectiveMainloop::ElementScale, + DummyElementScaleMainloop + >; + using ElementScale = std::conditional_t< + has_scale, + typename UnderlyingElement::type, + ElementWide + >; + using StrideScale = typename CollectiveMainloop::StrideScale; + // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S + using ElementZero = std::conditional_t< + has_zero, + typename CollectiveMainloop::ElementZero, + ElementScale + >; + const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m; + const size_t SZ_size = static_cast(SZ_1st_dim * scale_k * options_l); + auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l); + ElementScale *ptr_S = static_cast(arguments->Scale); + ElementZero *ptr_Z = static_cast(arguments->Zero); + + // 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled + if (arguments->generate_scale_and_zero) { + float scale_min = 1.0f, scale_max = 1.0f; + if constexpr(has_scale) { + const float elt_max_f = float(cutlass::platform::numeric_limits::max()); + // Need to fix max_dequant_val and min_dequant_val? + const float max_dequant_val = elt_max_f * 0.25f; + const float min_dequant_val = 0.5f; + scale_max = max_dequant_val / elt_max_f; + scale_min = min_dequant_val / elt_max_f; + } + uint64_t seed = 2023; + cutlass::reference::device::BlockFillRandomUniform( + ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min)); + + // In ScaleOnly mode, set Z as zero for generating dequantized A or B + const float zero_max = has_zero ? 2.0f : 0.0f; + const float zero_min = has_zero ? -2.0f : 0.0f; + cutlass::reference::device::BlockFillRandomUniform( + ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min)); + } // End of "if (arguments->generate_scale_and_zero)" + + // 2. Generate the dequantized A or B for verification + if (arguments->generate_dequantized_AB) { + StrideScale stride_SZ = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + if constexpr(is_StrideB_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of B later + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideB = typename CollectiveMainloop::StrideB; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } + else { + if constexpr(is_StrideA_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of A later + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideA = typename CollectiveMainloop::StrideA; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)" + } // End of "if (arguments->generate_dequantized_AB)" + + // 3. Put Scale and Zero in mainloop + if constexpr(has_scale) { + if constexpr(is_int4_x_fp8) { + operator_args.mainloop.ptr_S = static_cast(arguments->packed_Scale); + } + else { + operator_args.mainloop.ptr_S = static_cast(arguments->Scale); + } + operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + operator_args.mainloop.group_size = options_g; + if constexpr(has_zero) { + operator_args.mainloop.ptr_Z = static_cast(arguments->Zero); + } + } // End of "if constexpr(has_scale)" + + // Handle the shuffling + using ValueShuffle = std::conditional_t< + cutlass::sizeof_bits::value == 4, + cute::Layout, cute::Stride>, + cute::Layout, cute::Stride> + >; + constexpr int NumShuffleAtoms = 1; + using MmaAtomShape = cute::Layout>>; + using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout and stride of A/B later + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) { + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + using LayoutB_Reordered = typename CollectiveMainloop::StrideB; + handle_shuffle_tensor_( + operator_args, arguments, problem_n, problem_k, options_l); + } + if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) { + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + using LayoutA_Reordered = typename CollectiveMainloop::StrideA; + handle_shuffle_tensor_( + operator_args, arguments, problem_m, problem_k, options_l); + } + } // End of "if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{}))" + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + operator_args.hw_info.max_active_clusters = max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + [[maybe_unused]] void const *configuration_ptr, void const *arguments_ptr) const override { + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + OperatorArguments args; + + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + Status can_impl = Operator::can_implement(args); + + //return Operator::can_implement(args); + return can_impl; + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr), stream); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91f618d4fab74a6d43e2d82c572d215d5bea5a1c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp @@ -0,0 +1,873 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all grouped GEMM operations in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "gemm_operation_3x.hpp" +#include "library_internal.h" + +namespace cutlass::library { + +template +class GroupedGemmOperation3xBase : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + GroupedGemmOperation3xBase(char const* name = "unknown_gemm") + : GemmOperation3xBase(name, GemmKind::kGrouped) { + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.name = name; + this->description_.provider = Provider::kCUTLASS; + + this->description_.gemm = GemmOperation3xBase::description_; + this->description_.tile_description = this->description_.gemm.tile_description; + }; + +public: + mutable CudaBuffer strideA_device; + mutable CudaBuffer strideB_device; + mutable CudaBuffer strideC_device; + mutable CudaBuffer strideD_device; + + /// Returns the description of the GEMM operation + virtual OperationDescription const& description() const override final { return description_; } + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const* configuration) const override final { + return sizeof(Operator); + } + +protected: + library::GroupedGemmDescription description_; + + Status initialize_strides(GemmGroupedConfiguration const& config) const { + auto const num_groups = config.problem_count; + this->strideA_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups); + this->strideB_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups); + this->strideC_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups); + this->strideD_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups); + + std::vector strideA_host(num_groups); + std::vector strideB_host(num_groups); + std::vector strideC_host(num_groups); + std::vector strideD_host(num_groups); + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + strideA_host[group_idx] = + cute::make_int_tuple_from( + config.lda[group_idx]); + strideB_host[group_idx] = + cute::make_int_tuple_from( + config.ldb[group_idx]); + strideC_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + strideD_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + } + CUDA_CHECK(cudaMemcpy( + this->strideA_device.data(), + strideA_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideB_device.data(), + strideB_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideC_device.data(), + strideC_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideD_device.data(), + strideD_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups, + cudaMemcpyHostToDevice)); + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_base( + OperatorArguments& operator_args, + GemmGroupedArguments const& arguments) const { + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped; + operator_args.problem_shape = { + arguments.problem_count, + arguments.problem_sizes_3x, + arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host + : nullptr}; + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE5M2) { + return cute::UMMA::MXF8F6F4Format::E5M2; + } + else if (type == RuntimeDatatype::kE4M3) { + return cute::UMMA::MXF8F6F4Format::E4M3; + } + else if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } + else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } + else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e4m3." << std::endl; + #endif + return cute::UMMA::MXF8F6F4Format::E4M3; + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl; + #endif + return cute::UMMA::MXF4Format::E2M1; + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b); + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + } + operator_args.epilogue.ptr_C = static_cast(arguments.ptr_C); + operator_args.epilogue.ptr_D = static_cast(arguments.ptr_D); + + operator_args.mainloop.dA = + static_cast(this->strideA_device.data()); + operator_args.mainloop.dB = + static_cast(this->strideB_device.data()); + operator_args.epilogue.dC = + static_cast(this->strideC_device.data()); + operator_args.epilogue.dD = + static_cast(this->strideD_device.data()); + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments.sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + operator_args.hw_info.max_active_clusters = arguments.max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments.swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments.raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = + dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments.cluster_shape_fallback.m(), + arguments.cluster_shape_fallback.n(), + arguments.cluster_shape_fallback.k()); + } + return Status::kSuccess; + } + + template + static Status update_fusion_args(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } +}; + +/// **** CAUTION **** +/// Unlike other operations, initialize() must be called when +/// certain arguments change. See initialize() for details. +template +class GroupedGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + +public: + GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) {} + + ~GroupedGemmUniversal3xOperation() override = default; + +private: + int max_active_clusters{}; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + + /// Constructs the arguments structure given the configuration and arguments + Status + update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const { + + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + status = this->update_arguments_base(operator_args, *arguments); + return status; + } + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GemmGroupedArguments const* arguments = static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + Operator* op = new (host_workspace) Operator; + + auto const& config = *static_cast(configuration_ptr); + return this->initialize_strides(config); + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl); + return status; + } + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + Status initialize_with_arguments(void* arguments_ptr) const override { + if constexpr (Operator::ArchTag::kMinComputeCapability < 90) { + return Status::kSuccess; + } + + GemmGroupedArguments* args = static_cast(arguments_ptr); + + dim3 cluster_dims; + if constexpr (cute::is_static_v) { + cluster_dims = dim3( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{}) + ); + } + else { + cluster_dims = dim3( + args->cluster_shape.m(), + args->cluster_shape.n(), + args->cluster_shape.k() + ); + } + + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + + if (args->max_active_clusters == 0) { + std::cerr << "Max Active Clusters could not be queried. " + << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; + } + + return Status::kSuccess; + } +}; + +template +class GroupedBlockScaledGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + GroupedBlockScaledGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription block_scaled_desc{}; + block_scaled_desc.kind = OperationKind::kBlockScaledGemm; + block_scaled_desc.SFA.element = NumericTypeMap::kId; + block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFA.alignment = 128; + block_scaled_desc.SFA.log_extent_range = 32; + block_scaled_desc.SFA.log_stride_range = 32; + + block_scaled_desc.SFB.element = NumericTypeMap::kId; + block_scaled_desc.SFB.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFB.alignment = 128; + block_scaled_desc.SFB.log_extent_range = 32; + block_scaled_desc.SFB.log_stride_range = 32; + + block_scaled_desc.SFMVecSize = 1; + block_scaled_desc.SFNVecSize = 1; + block_scaled_desc.SFKVecSize = SFVecSize; + + block_scaled_desc.SFD = make_TensorDescription(128); + block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize; + + this->description_.block_scales = block_scaled_desc; + } + + ~GroupedBlockScaledGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockScaledArguments const& arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockScaledArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockScaledArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + +template +class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::ElementAccumulator; + using ElementSFB = typename Operator::ElementAccumulator; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription blockwise_desc{}; + blockwise_desc.kind = OperationKind::kBlockwiseGemm; + blockwise_desc.SFA.element = NumericTypeMap::kId; + blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ? + LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFA.log_extent_range = 32; + blockwise_desc.SFA.log_stride_range = 32; + + blockwise_desc.SFB.element = NumericTypeMap::kId; + blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFB.log_extent_range = 32; + blockwise_desc.SFB.log_stride_range = 32; + + blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + blockwise_desc.EpilogueSFVecSize = 0; + + this->description_.block_scales = blockwise_desc; + } + + ~GroupedBlockwiseGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockwiseArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockwiseArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + + +} // namespace cutlass::library diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..e8bd77397f3b85cce2da2a7a8e447ab6ccb48aea --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/library_internal.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/arch_mappings.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct NumericTypeMap; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kVoid; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE4M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE5M2; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE3M2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M1; +}; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE8M0; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE4M3; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF64; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF16; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF32; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kBF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kTF32; +}; + + + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF6; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF4; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddMixedInputUpcast; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct LayoutMap; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kRowMajor; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct OpcodeClassMap; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSimt; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kTensorOp; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ComplexTransformMap; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; +}; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ConvModeMap; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kCrossCorrelation; +}; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kConvolution; +}; + + +template struct ConvKindMap; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kFprop; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kDgrad; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kWgrad; +}; + + +template struct IteratorAlgorithmMap; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kAnalytic; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +TensorDescription make_TensorDescription(int alignment = 1) { + TensorDescription desc; + + desc.element = NumericTypeMap::kId; + desc.layout = LayoutMap::kId; + desc.alignment = alignment; + desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; + desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; + + return desc; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..76d8d0dfdb1aa6ed0324b9d6299b06ebf3f436d9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) + in CUTLASS Library. + + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/gemm/kernel/default_rank_2k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + Rank2KOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRank2K; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::Rank2Kkernel::WarpCount::kM, + Operator::Rank2Kkernel::WarpCount::kN, + Operator::Rank2Kkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperation : public Rank2KOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + Rank2KOperation(char const *name = "unknown_rank_2k"): + Rank2KOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Rank2KOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..021f7f03fcc4449bdc2ef2c97e29fe0fead09a64 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank K operation kinds (Syrk, Herk) + in CUTLASS Library. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/gemm/kernel/default_rank_k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + RankKOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRankK; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::RankKkernel::WarpCount::kM, + Operator::RankKkernel::WarpCount::kN, + Operator::RankKkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentA); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperation : public RankKOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + RankKOperation(char const *name = "unknown_rank_k"): + RankKOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->lda); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..6e948540e3f29dceace42b5e8ef3f91118c01b37 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for reduction operation in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/reduction/device/reduce_split_k.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReductionOperation : public Operation { +public: + using Operator = Operator_; + + using ElementWorkspace = typename Operator::ElementWorkspace; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementOutput = typename Operator::ElementOutput; + + using ElementCompute = typename Operator::OutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ReductionDescription description_; + +public: + + /// Constructor + ReductionOperation(char const *name = "unknown_reduction") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kReduction; + + description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); + + description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); + description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; + description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; + description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; + + description_.tile_description.minimum_compute_capability = 50; + description_.tile_description.maximum_compute_capability = 1024; + + description_.element_workspace = NumericTypeMap::kId; + description_.element_output = NumericTypeMap::kId; + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the Reduction operation + virtual OperationDescription const & description() const { + return description_; + } + + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + ReductionConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.partitions = configuration->partitions; + operator_args.partition_stride = configuration->partition_stride; + + operator_args.workspace = {nullptr, int(configuration->ldw)}; + operator_args.source = {nullptr, int(configuration->lds)}; + operator_args.destination = {nullptr, int(configuration->ldd)}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ReductionArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::OutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::OutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); + operator_args.source.reset(static_cast(const_cast(arguments->source))); + operator_args.destination.reset(static_cast(const_cast(arguments->destination))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + ReductionConfiguration const *configuration = + static_cast(configuration_ptr); + + ReductionArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Reduction" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run library::Reduction" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Reduction::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Reduction::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " partitions: " + << operator_args.partitions << std::endl + << " partition_stride: " + << operator_args.partition_stride << std::endl + << " epilogue (alpha, beta): " + << operator_args.output.alpha << ", " + << operator_args.output.beta << std::endl + << " workspace (ptr, stride): " + << operator_args.workspace.data() << ", " + << operator_args.workspace.stride(0) << std::endl + << " source (ptr, stride): " + << operator_args.source.data() << ", " + << operator_args.source.stride(0) << std::endl + << " destination (ptr, stride): " + << operator_args.destination.data() << ", " + << operator_args.destination.stride(0) << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..769da1c8515877536fd9b9fd72c836fd43ebd5d8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +namespace detail { +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ElementSFD_ = void, + typename LayoutSFD_ = LayoutC_, + int SFVecSize_ = 32, + int EpilogueSFVecSize_ = 0, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockScaledGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementSFD = ElementSFD_; + using LayoutSFD = LayoutSFD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + constexpr static int SFVecSize = SFVecSize_; + constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockScaledGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + description_.SFD = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFVecSize = SFVecSize; + description_.EpilogueSFVecSize = EpilogueSFVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockScaledGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + using Sm1xxBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + auto A = cute::make_tensor(detail::make_iterator(static_cast(args.A)), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + + auto B = cute::make_tensor(detail::make_iterator(static_cast(args.B)), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(detail::make_iterator(static_cast(args.C)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(detail::make_iterator(static_cast(nullptr)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(detail::make_iterator(static_cast(args.D)), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + if constexpr (not is_same_v) { + + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + EpilogueSFVecSize + >; + + auto SfD = cute::make_tensor(detail::make_iterator(static_cast(args.SFD)), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD), Int, cutlass::reference::host::SfStrategy::SfDGen> + epilogue_params{alpha, beta, C, D, SfD, *(static_cast(args.norm_constant))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + // W/O SF generation + auto SfD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L))); // not used. + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD)> + epilogue_params{alpha, beta, C, D, SfD}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + return Status::kSuccess; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm_tn(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm(Manifest &manifest) { + /// + /// A is Row , B is Col + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + /// + /// A is Col , B is Row + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fd988f899f563acfc6f8003bdb49523bca51d6d9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h @@ -0,0 +1,807 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename LayoutSFA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename LayoutSFB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockwiseGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_) + : SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockwiseGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFMVecSize = SFMVecSize; + description_.SFNVecSize = SFNVecSize; + description_.SFKVecSize = SFKVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockwiseGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>; + auto A = cute::make_tensor(static_cast(args.A), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto B = cute::make_tensor(static_cast(args.B), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(static_cast(args.C), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(static_cast(args.D), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + // W/O SF generation + cutlass::reference::host::GettEpilogueParams< + ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute, + decltype(C), decltype(D)> + epilogue_params{alpha, beta, C, D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return Status::kSuccess; + } + +private: + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) { + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + +} + +template +void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..240fe18d16a27778bf75e0c02f99d251c096353f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h @@ -0,0 +1,636 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + Provider kProvider, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +struct ConvReferenceDispatcher; + +/// Dispatcher for Conv2d (partially specialized for kConvDim == 2) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 2, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv2dConfiguration const &config = + *static_cast(configuration); + + // TODO: make below code more general. It is fixed for NHWC now. + layout::TensorNHWC layout_a; + layout::TensorNHWC layout_b; + layout::TensorNHWC layout_c; + + layout_a.stride() = + make_Coord(int32_t(config.stride_a[0]), + int32_t(config.stride_a[1]), + int32_t(config.stride_a[2])); + + layout_b.stride() = + make_Coord(int32_t(config.stride_b[0]), + int32_t(config.stride_b[1]), + int32_t(config.stride_b[2])); + + layout_c.stride() = + make_Coord(int32_t(config.stride_c[0]), + int32_t(config.stride_c[1]), + int32_t(config.stride_c[2])); + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 3, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv3dConfiguration const &config = + *static_cast(configuration); + + ConvKind const conv_kind = ConvKindMap::kId; + + if (kProvider == Provider::kReferenceHost) { + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class ConvReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + static cutlass::conv::Operator const kConvolutionalOperator = ConvolutionalOperator; + static int const kConvDim = ConvDim; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + ConvDescription description_; + +public: + + /// Constructor + ConvReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d); + description_.conv_kind = ConvKindMap::kId; + description_.conv_dim = kConvDim; + + // Tensor description + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Iterator algorithm for convolution reference + description_.iterator_algorithm = IteratorAlgorithmID::kNone; + + // Compute capability for convolution reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + switch (kConvDim) { + case 2: + return sizeof(Conv2dConfiguration); + case 3: + return sizeof(Conv3dConfiguration); + default: + break; + } + + return 0; + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + ConvArguments const &args = *static_cast(arguments); + + ElementCompute alpha; + ElementCompute beta; + + alpha = *static_cast(args.alpha); + beta = *static_cast(args.beta); + + // TODO - respect pointer mode + + // Invoke 2D or 3D convolution + return detail::ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + kConvDim, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >::dispatch( + host_workspace, + static_cast(const_cast(args.A)), + static_cast(const_cast(args.B)), + static_cast(const_cast(args.C)), + static_cast(args.D), + alpha, + beta, + stream + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs Fprop reference operators. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_fprop(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Constructs Dgrad and Wgrad reference operators. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_backwards(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Six operators for the price of one. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_all(Manifest &manifest) { + + make_conv_fprop< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_conv_backwards< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..e07158b0602eef1d71cfdca95323b3da60553747 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h @@ -0,0 +1,543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for GEMM operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class GemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + static cutlass::ComplexTransform const kTransformA = TransformA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + static cutlass::ComplexTransform const kTransformB = TransformB; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.transform_A = ComplexTransformMap::kId; + description_.B = make_TensorDescription(); + description_.transform_B = ComplexTransformMap::kId; + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + GemmUniversalConfiguration const &config = *static_cast(host_workspace); + GemmUniversalArguments const &args = *static_cast(arguments); + + TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; + TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; + TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; + TensorRefD ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + + cutlass::reference::device::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + + return Status::kErrorNotSupported; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new GemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new GemmReferenceOperation< + Provider::kReferenceDevice, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); +#endif +} + +/// Helper to create NN, NT, TN, and TT GEMM layouts. +template < + typename ElementA_, cutlass::ComplexTransform TransformA, + typename ElementB_, cutlass::ComplexTransform TransformB, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_canonical_layouts(Manifest &manifest) { + + // M Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + // N Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + + +/// Helper to create TN and interleaved layouts GEMM layouts. +template < + int InterleaveK, + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_interleaved_layouts(Manifest &manifest) { + + make_gemm< + ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + +} + +/// Helper to real-valued GEMM with canonical layouts +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_real_canonical_layouts(Manifest &manifest) { + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +// Helper to create all complex transformation permutations +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_complex_canonical_layouts(Manifest &manifest) { + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..01caa11e229ffd9109b0973dcca01064df448fa3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp @@ -0,0 +1,504 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor +#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter +#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride +#include "gemm_operation_3x.hpp" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Limitation & Assumptions: +// 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, +// and lda is m if the tensor is m-major. +// 2. Circular buffer for tensorA and tensorE may have a less count compared to tensorB and others. +// This is because we can not get the problem_count information in the get_device_workspace_size(). +// But I can promise it will use at least 192MB memory if we enable circular buffer. +template +class SparseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementE = typename CollectiveMainloop::ElementE; + using LayoutE = typename CollectiveMainloop::LayoutE; + using SparseConfig = typename CollectiveMainloop::SparseConfig; + using LayoutATag = decltype(SparseConfig::deduce_layoutA_tag(typename CollectiveMainloop::LayoutA{})); + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutATag, + SparseConfig, + typename Operator::ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +public: + + /// Constructor + SparseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) {} + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + CompressorUtility const& compressor_utility, + void* device_a_compressed_ptr = nullptr, + void* device_e_ptr = nullptr) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_E = static_cast(device_e_ptr); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor(); + operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor(); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto problem_shape_MNKL = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + const int M = configuration->problem_size.m(); + const int N = configuration->problem_size.n(); + const int K = configuration->problem_size.k(); + const int L = configuration->batch_count; + using StrideA = typename CompressorUtility::StrideA; + auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + compressor_utility.set_problem_size(problem_shape_MNKL, dA); + auto status = update_arguments_(args, arguments, compressor_utility); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = problem_shape_MNKL; + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *) const override { + // Memory to hold operator + host_op_workspace_size = sizeof(Operator); + + // Memory to hold result of `.structure_sparse_zero_mask_fill()` + tensor_a_size = compressor_utility.get_raw_tensor_A_bytes(); + + // NOTE: order here is the order of workspace partition + const uint64_t size = host_op_workspace_size + tensor_a_size; + + return size; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr), compressor_utility); + if (status != Status::kSuccess) { + return 0; + } + + typename Compressor::Arguments compress_arguments { + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {/*Empty Not Use*/}, + {/*Empty Not Use*/} }; + + // Size for one iteration + // For multi-iteration, will need to multiply result of this function w/ actual problem_count + tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes(); + tensor_e_size = compressor_utility.get_tensor_E_bytes(); + device_op_workspace_size = Operator::get_workspace_size(args); + device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments); + + // NOTE: order here is the order of workspace partition + device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size; + + return device_per_iter_workspace_size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + return Status::kErrorInternal; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + + iter_idx.resize(static_cast(configuration)->device_count, 0); + + // Set problem_count. + problem_count = problem_count_from_profiler; + + // * Host Ptr + auto* host_op_workspace_ptr = reinterpret_cast(host_workspace); + auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size; + + // * Construct Op + Operator *op = new (host_op_workspace_ptr) Operator; + + // * Device Ptr (1st iteration) + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iter1 = static_cast(device_workspace); + auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; + auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; + auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; + auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size; + + // * Device A Raw Ptr + auto* device_a_raw_ptr = profiler_workspaces[0]; + + // * Random fill 50% of TensorA w/ zero following the structured sparse requirement + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); + compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK(cudaGetLastError()); + + // * Compress DTensorA and get DTensorAC & DTensorE + cutlass::KernelHardwareInfo hw_info; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {device_a_raw_ptr, + compressor_utility.dA, + device_a_compressed_ptr_iter1, + device_e_ptr_iter1}, + {hw_info} + }; + + cutlass::Status status {cutlass::Status::kSuccess }; + + Compressor compressor_op; + status = compressor_op.can_implement(arguments); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.run(stream); + if (status != Status::kSuccess) { + return status; + } + + // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE + for (int iter_i = 1; iter_i < problem_count; iter_i++) { + // * Device AC E Ptr per iteration + // Device workspace : | iter1 | iter2 | iter3 | .. | iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); + } + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + CUDA_CHECK(cudaGetLastError()); + + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + + + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; + + Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } + +private: + // Variables that must change in the const functions. + mutable CompressorUtility compressor_utility; + mutable int problem_count = 1; + mutable std::vector iter_idx; + + mutable uint64_t tensor_ac_size = 0; + mutable uint64_t tensor_e_size = 0; + mutable uint64_t tensor_a_size = 0; + mutable uint64_t host_op_workspace_size = 0; + mutable uint64_t device_compress_workspace_size = 0; + mutable uint64_t device_op_workspace_size = 0; + mutable uint64_t device_per_iter_workspace_size = 0; +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..c95d238a81f825dbbeae689ec452467cc8ca3afa --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h @@ -0,0 +1,382 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Symm operation kinds (Symm, Hemm) + in CUTLASS Library. + + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/symm.h" +#include "cutlass/gemm/kernel/default_symm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + SymmDescription description_; + +public: + + /// Constructor + SymmOperationBase(char const *name = "unknown_symm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.symm_kind = SymmKind::kUniversal; + description_.side_mode = kSideModeA; + description_.fill_mode = kFillModeA; + description_.blas_mode = kBlasMode; + + description_.kind = OperationKind::kSymm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::SymmKernel::WarpCount::kM, + Operator::SymmKernel::WarpCount::kN, + Operator::SymmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the SYMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperation : public SymmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + SymmOperation(char const *name = "unknown_symm"): + SymmOperationBase(name) { + + this->description_.symm_kind = SymmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SymmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SymmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SymmConfiguration const *configuration = + static_cast(configuration_ptr); + + SymmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && + std::is_same::value) || + (kSideModeA == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "SymmOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..d419723791ace5d90eb7955223be9db72bbc2c3c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all TRMM operation kinds in CUTLASS Library. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/trmm.h" +#include "cutlass/gemm/kernel/default_trmm_universal.h" +#include "cutlass/gemm/kernel/trmm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + TrmmDescription description_; + +public: + + /// Constructor + TrmmOperationBase(char const *name = "unknown_trmm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kTrmm; + description_.trmm_kind = TrmmKind::kUniversal; + description_.side_mode = kSideMode; + description_.fill_mode = kFillMode; + description_.diag_type = kDiagType; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::TrmmKernel::WarpCount::kM, + Operator::TrmmKernel::WarpCount::kN, + Operator::TrmmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + } + + /// Returns the description of the TRMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperation : public TrmmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + TrmmOperation(char const *name = "unknown_trmm"): + TrmmOperationBase(name) { + + this->description_.trmm_kind = TrmmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + TrmmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + TrmmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.ptr_D = arguments->D; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + TrmmConfiguration const *configuration = + static_cast(configuration_ptr); + + TrmmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideMode == SideMode::kLeft && + std::is_same::value) || + (kSideMode == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..5d500d9149bf645eadf8110d98612c40882d742c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockScaledGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + DeviceAllocation *Computed_SFD{nullptr}; + DeviceAllocation *Reference_SFD{nullptr}; + DeviceAllocation *Norm_constant{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockScaledGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockScaledGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockScaledGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..c110de278cac640c1cedd8dd29d1b8ac09de81ef --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockwiseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + int64_t sf_vec_m{0}; + int64_t sf_vec_n{0}; + int64_t sf_vec_k{0}; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockwiseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockwiseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockwiseGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..683465f50cda19c8d505f2e66bcb60173d7e942d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv2dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv2dProblem { + + int64_t n, h, w, c, p, q, k, r, s; + int64_t groups; + int64_t pad_h, pad_w; + int64_t stride_h, stride_w; + int64_t dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + void set_default_output_size() { + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; + case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *reordered_B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + // + // Methods + // + + Conv2dWorkspace() + : A(nullptr), + B(nullptr), + reordered_B(nullptr), + C(nullptr), + Computed(nullptr), + Reference(nullptr) {} + + // Set stride vector for tensor activations, filters, output + void set_stride_vector(Conv2dProblem const &problem, + library::ConvKind const &conv_kind, + library::LayoutTypeID const &layout_a, + library::LayoutTypeID const &layout_b, + library::LayoutTypeID const &layout_c) { + std::vector stride_activations; + std::vector stride_filters; + std::vector stride_output; + + // Strides for interleaved fprop + if (conv_kind == library::ConvKind::kFprop && + ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && + layout_b == library::LayoutTypeID::kTensorC32RSK32 && + layout_c == library::LayoutTypeID::kTensorNC32HW32) || + (layout_a == library::LayoutTypeID::kTensorNC64HW64 && + layout_b == library::LayoutTypeID::kTensorC64RSK64 && + layout_c == library::LayoutTypeID::kTensorNC64HW64))) { + int interleave = + (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 32 : 64; + + stride_activations.push_back(int(problem.w) * interleave); + stride_activations.push_back(int(problem.w) * int(problem.h) * + interleave); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.k) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * + int(problem.r) * interleave); + + stride_output.push_back(int(problem.q) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } else { + // Strides for the rest cases + stride_activations.push_back(int(problem.c)); + stride_activations.push_back(int(problem.w) * int(problem.c)); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.r) * int(problem.s) * + int(problem.c / problem.groups)); + + stride_output.push_back(int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } + + switch (conv_kind) { + case library::ConvKind::kFprop: + configuration.stride_a = stride_activations; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_output; + + break; + case library::ConvKind::kDgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_activations; + + break; + case library::ConvKind::kWgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_activations; + configuration.stride_c = stride_filters; + + break; + default: + throw std::runtime_error( + "Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv2dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv2dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv2dOperationProfiler(); + + Conv2dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ac4abdef238b00f216053419620a60dfccfd5316 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h @@ -0,0 +1,449 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv3dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv3dProblem { + + int64_t n, d, h, w, c, z, p, q, k, t, r, s; + int64_t pad_d, pad_h, pad_w; + int64_t stride_d, stride_h, stride_w; + int64_t dilation_d, dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + /// Infers output size from the input size, padding, stride, and dilation + void set_default_output_size() { + z = ((d + pad_d - t * dilation_d) / stride_d) + 1; + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv3dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + + // + // Methods + // + + Conv2dWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + + // Returns stride vector for tensor A + std::vector stride_a(library::ConvKind const &conv_kind) { + return { + configuration.layout_a(conv_kind).stride()[0], + configuration.layout_a(conv_kind).stride()[1], + configuration.layout_a(conv_kind).stride()[2], + configuration.layout_a(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor B + std::vector stride_b(library::ConvKind const &conv_kind) { + + return { + configuration.layout_b(conv_kind).stride()[0], + configuration.layout_b(conv_kind).stride()[1], + configuration.layout_b(conv_kind).stride()[2], + configuration.layout_b(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor C + std::vector stride_c(library::ConvKind const &conv_kind) { + + return { + configuration.layout_c(conv_kind).stride()[0], + configuration.layout_c(conv_kind).stride()[1], + configuration.layout_c(conv_kind).stride()[2], + configuration.layout_c(conv_kind).stride()[3] + }; + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv3dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv3dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv3dOperationProfiler(); + + Conv3dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Updates the arguments structure for the CUTLASS operator based on + /// the problem index. + void set_cutlass_operator_arguments_(int problem_idx = 0); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..873ba1abe03c05df29edc032ea3f1ffd2f19c3ee --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h @@ -0,0 +1,456 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuBLAS. +*/ + +#pragma once + +#if CUTLASS_ENABLE_CUBLAS +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/blas3.h" + +#include "options.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Converts a cuBLAS status to cutlass::Status +Status get_cutlass_status(cublasStatus_t cublas); + +/// Converts a cuBLAS status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cublasStatus_t cublas_status); + +/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation +bool get_cublas_transpose_operation( + cublasOperation_t &operation, + library::LayoutTypeID layout, + library::ComplexTransform transform = library::ComplexTransform::kNone); + +/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration +bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); + +/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class +cublasGemmAlgo_t get_cublas_gemm_algo( + int cta_m, + int cta_n, + int cta_k, + library::OpcodeClassID opcode_class); + +/// Returns a status if cuBLAS can satisfy a particular GEMM description +Status cublas_satisfies(library::GemmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular RankK description +Status cublas_satisfies(library::RankKDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular TRMM description +Status cublas_satisfies(library::TrmmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description +Status cublas_satisfies(library::SymmDescription const &desc); + +/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and +/// to destroy cublasHandle_t on CublasCreate object destruction. +/// Additionally, it provides implicit cast from CublasCreate's object to cublasHandle_t's object +class CublasCreate { +private: + cublasHandle_t handle; + cublasStatus_t status; + +public: + CublasCreate() { + status = cublasCreate(&handle); + } + + ~CublasCreate() { + cublasDestroy(handle); + } + + /// Implicit cast CublasCreate object to cublasHandle_t + operator cublasHandle_t() const { return handle; } + + /// returns cublasStatus_t for handle creation + cublasStatus_t get_cublas_create_status() { return status; } +}; + +/// This is a helper class to create cublasLtHandle_t automatically on CublasLtCreate object creation and +/// to destroy cublasLtHandle_t on CublasLtCreate object destruction. +/// Additionally, it provides implicit cast from CublasLtCreate's object to cublasLtHandle_t's object +class CublasLtCreate { +private: + cublasLtHandle_t handle; + cublasStatus_t status; + +public: + CublasLtCreate() { + status = cublasLtCreate(&handle); + } + + ~CublasLtCreate() { + cublasLtDestroy(handle); + } + + /// Implicit cast CublasLtCreate object to cublasLtHandle_t + operator cublasLtHandle_t() const { return handle; } + + /// returns cublasLtStatus_t for handle creation + cublasStatus_t get_cublaslt_create_status() { return status; } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Selects one or more cuBLAS algorithms. +static void select_cublas_algorithms( + std::vector &algorithms, + Options const &options, + library::GemmDescription const &op_desc) { + + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + switch (options.library.algorithm_mode) { + case AlgorithmMode::kMatching: + { + algorithms.push_back(get_cublas_gemm_algo( + op_desc.tile_description.threadblock_shape.m(), + op_desc.tile_description.threadblock_shape.n(), + op_desc.tile_description.threadblock_shape.k(), + opcode_class)); + break; + } + + case AlgorithmMode::kBest: + { + // Choose first enumerated mode. If none are enumerated, choose based on opcode class + // and evaluate all of them. + + if (options.library.algorithms.empty()) { + // Enumerate all algorithms + if (opcode_class == library::OpcodeClassID::kSimt) { + + for (int algo = CUBLAS_GEMM_DEFAULT; + algo <= CUBLAS_GEMM_ALGO23; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + else { + + for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + } + else { + // Use the listed algorithms + algorithms.reserve(options.library.algorithms.size()); + + for (int algo : options.library.algorithms) { + algorithms.push_back(reinterpret_cast(algo)); + } + } + + break; + } + + case AlgorithmMode::kDefault: + { + + // Use the library's default algorithm + algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? + CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + break; + } + default: + { + break; + } + } +} + +/// Dispatcher to cublasGemmEx() +struct cublasGemmExDispatcher { + + // + // Data members + // + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + cublasGemmAlgo_t algo; + Status status; + + // + // Methods + // + + cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_, + cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/// Dispatcher to cublaslt kernels +// +struct cublasLtGemmExDispatcher { + + // + // Data members + // + library::GemmDescription const &op_desc; + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type = CUDA_R_32F; + + //cublasLt-specific data structures + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL; + cublasLtMatmulPreference_t preference = NULL; + + //is set by call to get_cublaslt_algo() + cublasLtMatmulHeuristicResult_t heuristicResult_; + void *workspace = nullptr; + + Status status; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + // + // Methods + // + + cublasLtGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_ + ); + + /// Initialize the cublasLt variables + void initialize_cublaslt(); + + + /// Runs auto-tuning for the cublas heuristics + bool get_cublaslt_algo(cublasLtHandle_t handle, + AlgorithmMode algorithm_mode + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); + + ~cublasLtGemmExDispatcher(){ + + // descriptors are no longer needed as all GPU work was already enqueued + if (preference) cublasLtMatmulPreferenceDestroy(preference); + if (Ddesc) cublasLtMatrixLayoutDestroy(Ddesc); + if (Cdesc) cublasLtMatrixLayoutDestroy(Cdesc); + if (Bdesc) cublasLtMatrixLayoutDestroy(Bdesc); + if (Adesc) cublasLtMatrixLayoutDestroy(Adesc); + if (operationDesc) cublasLtMatmulDescDestroy(operationDesc); + + if (workspace) { + cudaFree(workspace); + } + + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas rank k update kernels +struct cublasRankKDispatcher { + + // + // Data members + // + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + int num_ranks; //(rank-k or rank-2k) + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasRankKDispatcher( + library::RankKDescription const &op_desc, + library::RankKConfiguration configuration_, + library::RankKArguments arguments_ + ); + + /// Executes RankK using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublasTrmm() +struct cublasTrmmDispatcher { + + // + // Data members + // + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasSideMode_t side; + cublasFillMode_t uplo; + cublasDiagType_t diag; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_D; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + Status status; + + // + // Methods + // + + cublasTrmmDispatcher( + library::TrmmDescription const &op_desc, + library::TrmmConfiguration configuration_, + library::TrmmArguments arguments_ + ); + + /// Executes TRMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas symm/hemm update kernels +struct cublasSymmDispatcher { + + // + // Data members + // + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasSideMode_t side; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasSymmDispatcher( + library::SymmDescription const &op_desc, + library::SymmConfiguration configuration_, + library::SymmArguments arguments_ + ); + + /// Executes Symm using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +} // namespace profiler +} // namespace cutlass + + +#endif // #if CUTLASS_ENABLE_CUBLAS diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..7ce9eea5a883fa4c5732f5d8aec120a99064bac0 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h @@ -0,0 +1,590 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuDNN. + +*/ + +#pragma once +#if CUTLASS_ENABLE_CUDNN +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/library/library.h" +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Converts a cuDNN status to cutlass::Status +Status get_cutlass_status(cudnnStatus_t cudnn_status); + +/// Converts a cuDNN status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); + +/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception +Status checkCudnnErr(cudnnStatus_t cudnn_status); + +/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration +bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); + +/// Maps a CUTLASS layout type to a cuDNN data type enumeration +bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); + +/// Maps a CUTLASS numeric type to a cuDNN data type enumeration +bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); + +/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type +bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); + +/// Returns a status if cudnn can satisfy a particular Conv2d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); + +/// Returns a status if cudnn can satisfy a particular Conv3d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); + +/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) +float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); + + +/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and +/// to destroy cudnnHandle_t on CudnnCreate object destruction. +/// Additionally, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object +class CudnnCreate { +private: + cudnnHandle_t handle; + cudnnStatus_t status; + +public: + CudnnCreate() { + status = cudnnCreate(&handle); + } + + ~CudnnCreate() { + cudnnDestroy(handle); + } + + /// Implicit cast CudnnCreate object to cudnnHandle_t + operator cudnnHandle_t() const { return handle; } + + /// returns cudnnStatus_t for handle creation + cudnnStatus_t get_cudnn_create_status() { return status; } +}; + + +namespace detail { + +/// Dispatcher to cudnn convolution operators +struct cudnnConvDispatcher { + + // + // Data members + // + //library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + library::ConvKind conv_kind; + + // cudnn-specific data structures to fill cudnn API call arguments + // cudnn activation, filter, and output descriptors + cudnnTensorDescriptor_t activation_desc; + cudnnFilterDescriptor_t filter_desc; + cudnnTensorDescriptor_t output_desc; + cudnnConvolutionDescriptor_t conv_desc; + + // cudnn datatypes + cudnnDataType_t data_type_activation; + cudnnDataType_t data_type_filter; + cudnnDataType_t data_type_output; + + // cudnn layouts + cudnnTensorFormat_t layout_activation; + cudnnTensorFormat_t layout_filter; + cudnnTensorFormat_t layout_output; + + // cudnn convolution mode + cudnnConvolutionMode_t conv_mode; + + // cudnn math type (tensorop, tensorop with conversion, simt) + cudnnMathType_t math_type; + + // cudnn compute data type + cudnnDataType_t compute_type; + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + float alpha; + float beta; + + // cudnn workspace + size_t workspace_size_in_bytes = 0; + cutlass::device_memory::allocation workspace; + + // select cudnn's implicit gemm precomputed algorithm with tensor operations + static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + + Status status; + + // + // Methods + // + + // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfiguration + + // ctor for conv2d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv2dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + // Get cudnn mathtype (cudnnMathType_t) + good = (good && get_cudnn_mathtype(math_type, op_desc)); + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + return; + } + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation, filter, and output descriptor + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + activation_desc, + layout_activation, + data_type_activation, + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.H, + configuration.problem_size.W + )); + + status = get_cutlass_status( + cudnnSetFilter4dDescriptor( + filter_desc, + data_type_filter, + layout_filter, + configuration.problem_size.K, + configuration.problem_size.C / configuration.problem_size.groups, + configuration.problem_size.R, + configuration.problem_size.S + )); + + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + output_desc, + layout_output, + data_type_output, + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.P, + configuration.problem_size.Q + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + + // ctor for conv3d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv3dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + } + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation descriptor + std::vector activation_extent { + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.D, + configuration.problem_size.H, + configuration.problem_size.W + }; + + std::vector activation_stride { + configuration.layout_activations.stride()[3], + 1, + configuration.layout_activations.stride()[2], + configuration.layout_activations.stride()[1], + configuration.layout_activations.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + activation_desc, + data_type_activation, + op_desc.conv_dim + 2, + activation_extent.data(), + activation_stride.data() + )); + + // Set filter descriptor + std::vector filter_extent { + configuration.problem_size.K, + configuration.problem_size.C, + configuration.problem_size.T, + configuration.problem_size.R, + configuration.problem_size.S + }; + + std::vector filter_stride { + configuration.layout_filters.stride()[3], + 1, + configuration.layout_filters.stride()[2], + configuration.layout_filters.stride()[1], + configuration.layout_filters.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetFilterNdDescriptor( + filter_desc, + data_type_filter, + layout_filter, + op_desc.conv_dim + 2, + filter_extent.data() + )); + + + // Set output descriptor + std::vector output_extent { + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.Z, + configuration.problem_size.P, + configuration.problem_size.Q + }; + + std::vector output_stride { + configuration.layout_output.stride()[3], + 1, + configuration.layout_output.stride()[2], + configuration.layout_output.stride()[1], + configuration.layout_output.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + output_desc, + data_type_output, + op_desc.conv_dim + 2, + output_extent.data(), + output_stride.data() + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + /// Executes Conv2d operator from cudnn library + cudnnStatus_t operator()(cudnnHandle_t handle) { + + switch (conv_kind) { + case library::ConvKind::kFprop: + return cudnnConvolutionForward( + handle, + &alpha, + activation_desc, + activation(), + filter_desc, + filter(), + conv_desc, + fprop_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + output_desc, + arguments.D + ); + case library::ConvKind::kDgrad: + return cudnnConvolutionBackwardData( + handle, + &alpha, + filter_desc, + filter(), + output_desc, + output(), + conv_desc, + dgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + activation_desc, + arguments.D + ); + case library::ConvKind::kWgrad: + return cudnnConvolutionBackwardFilter( + handle, + &alpha, + activation_desc, + activation(), + output_desc, + output(), + conv_desc, + wgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + filter_desc, + arguments.D + ); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Activation Tensor + void const * activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.A; + case library::ConvKind::kDgrad : return arguments.C; + case library::ConvKind::kWgrad : return arguments.B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter Tensor + void const *filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.B; + case library::ConvKind::kDgrad : return arguments.B; + case library::ConvKind::kWgrad : return arguments.C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output Tensor + void const *output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.C; + case library::ConvKind::kDgrad : return arguments.A; + case library::ConvKind::kWgrad : return arguments.A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +} // namespace detail +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif //#if CUTLASS_ENABLE_CUDNN +} // namespace profiler +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..be82245325cebb147e2c801965a52ece91395cb2 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#pragma once +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +#include "options.h" +#include "operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CUTLASS Profiler application +class CutlassProfiler { +private: + + // + // Data members + // + + /// Performance testbench options + Options options_; + + /// Entry points for each operation + OperationProfilerVector operation_profilers_; + +private: + + /// Prints usage + void print_usage_(std::ostream &); + + /// Prints usage + void print_options_(std::ostream &); + + /// Enumerates all operations + void enumerate_(); + + /// Profiles all operations + int profile_(); + +public: + + CutlassProfiler(Options const &options); + ~CutlassProfiler(); + + /// Invokes profiling operations + int operator()(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..98f1fdc3044501e456c927471b30d74b09eafd39 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h @@ -0,0 +1,56 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include + +//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } +//#define report(x) {} + +// Enable/Disable Profiler debug prints +//#define DEBUG_PROFILER + +//RED 31m // profiler prints debug messages in red +//YELLOW 33m // ir prints debug messages in yellow + +#ifndef DEBUG_PROFILER +#define debugprof(...) +#else +#define debugprof(...) do { \ + printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ + printf(__VA_ARGS__); \ + printf("\033[0m\n"); \ + } while (0) +#endif diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h new file mode 100644 index 0000000000000000000000000000000000000000..488b635c2ec233e3027303bbf15a34f375a438fd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/util/distribution.h" + +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device memory allocation +class DeviceAllocation { +private: + + /// Data type of contained elements + library::NumericTypeID type_; + + /// Gets the stride between elements + size_t batch_stride_; + + /// Capacity in elements of device allocation + size_t capacity_; + + /// Pointer to device memory + void *pointer_; + + /// Layout type ID + library::LayoutTypeID layout_; + + /// Stride vector + std::vector stride_; + + /// Extent vector + std::vector extent_; + + /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory + int batch_count_; + + /// Buffer holding TensorRef instance to recently allocated memory + std::vector tensor_ref_buffer_; + + /// The device ID where the allocation is made + int device_; + +public: + // + // Static member functions + // + + /// Determines the number of bytes needed to represent this numeric type + static size_t bytes(library::NumericTypeID type, size_t capacity); + + /// Returns the stride of a packed layout + static std::vector get_packed_layout( + library::LayoutTypeID layout_id, + std::vector const &extent); + + /// returns the capacity needed + static size_t construct_layout( + void *bytes, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector &stride); + + /// Returns true if two blocks have exactly the same value + static bool block_compare_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity); + + /// Returns true if two blocks have approximately the same value + static bool block_compare_relatively_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity, + double epsilon, + double nonzero_floor); + +public: + // + // Methods + // + + DeviceAllocation(); + + DeviceAllocation( + library::NumericTypeID type, + size_t capacity, + int device = -1); + + DeviceAllocation( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1, + int device = -1); + + ~DeviceAllocation(); + + DeviceAllocation &reset(); + + /// Allocates device memory of a given type and capacity + DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); + + /// Allocates memory for a given layout and tensor + DeviceAllocation &reset( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1); + + /// Returns a buffer owning the tensor reference + std::vector &tensor_ref() { + return tensor_ref_buffer_; + } + + bool good() const; + + /// Data type of contained elements + library::NumericTypeID type() const; + + /// Pointer to start of device memory allocation + void *data() const; + + /// Pointer to the first element of a batch + void *batch_data(int batch_idx) const; + + /// Gets the layout type + library::LayoutTypeID layout() const; + + /// Gets the stride vector + std::vector const & stride() const; + + /// Gets the extent vector + std::vector const & extent() const; + + /// Gets the number of adjacent tensors in memory + int batch_count() const; + + /// Gets the stride (in units of elements) between items + int64_t batch_stride() const; + + /// Gets the stride (in units of bytes) between items + int64_t batch_stride_bytes() const; + + /// Capacity of allocation in number of elements + size_t capacity() const; + + /// Capacity of allocation in bytes + size_t bytes() const; + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_device(int seed, Distribution dist); + + /// Initializes a host allocation to a random distribution using std::cout + void initialize_random_host(int seed, Distribution dist); + + /// Initializes a device allocation to a sequential distribution + void initialize_sequential_device(Distribution dist); + + /// Initializes a host allocation to a sequential distribution + void initialize_sequential_host(Distribution dist); + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); + + /// Initializes a host allocation to a random distribution using std::cout + void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); + + /// Uniformly fills a tensor with a value when provided o.w. zero + void fill_device(double value); + + /// Uniformly fills a host allocation with a value when provided o.w. zero + void fill_host(double value); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_device(void const *ptr); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_host(void const *ptr); + + /// Copies from an equivalent-sized tensor in device memory + void copy_to_host(void *ptr); + + /// Writes a tensor to csv + void write_tensor_csv(std::ostream &out); + +private: + /// A wrapper that sets the device, performs malloc, and sets back + cudaError_t malloc(void** ptr, size_t size); +}; + +using DeviceAllocationList = std::list; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h new file mode 100644 index 0000000000000000000000000000000000000000..0443b340397426bfafc812c1a4b9179fc6af0de4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include +#include + + +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + +#include "options.h" +#include "device_allocation.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collection of allocations on the device +class DeviceContext { +public: + + // + // Type definitions + // + using AllocationMap = std::map; + +private: + // + // Data members + // + + /// Memory allocations that exist (owning) + DeviceAllocationList device_memory_; + + /// Non-owning set of named allocations + AllocationMap allocations_; + +public: + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_block( + Options const &options, + std::string const &name, + library::NumericTypeID type, + size_t capacity, + size_t device_index); + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + size_t device_index); + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_and_initialize_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Allocates memory for sparse meta data + DeviceAllocation *allocate_and_initialize_sparsemeta_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + library::NumericTypeID type_a, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Clears named allocations (but does not necessarily free memory) + void clear(); + + /// Frees all device memory allocations + void free(); + + /// Gets the allocation by name + DeviceAllocation &at(std::string const &name); + + size_t size() const; + + AllocationMap::iterator begin(); + AllocationMap::iterator end(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h new file mode 100644 index 0000000000000000000000000000000000000000..897311c228ce76c4e8814ce996929561d44d2465 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/library/library.h" + +#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +T from_string(std::string const &); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing how the performance testbench evaluates kernels. +enum class ExecutionMode { + kProfile, ///< regular verification and profiling + kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched + kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations + kTrace, ///< executes a single device-side computation with no other kernel launches + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(ExecutionMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +ExecutionMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Library algorithm mode +enum class AlgorithmMode { + kMatching, ///< compare against best matching algorithm + kBest, ///< evaluate all library algorithms and report best + kDefault, ///< use the library's default algorithm option + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(AlgorithmMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +AlgorithmMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Outcome of a performance test +enum class Disposition { + kPassed, + kFailed, // kernel itself reported an error + kNotRun, + kIncorrect, // kernel finished without a detected error, but result does not equal expected result + kNotVerified, + kInvalidProblem, + kNotSupported, + kInvalid +}; + +/// Converts a Disposition enumerant to a string +char const *to_string(Disposition disposition, bool pretty = false); + +/// Parses a Disposition enumerant from a string +template <> +Disposition from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates when to save +enum class SaveWorkspace { + kNever, + kIncorrect, + kAlways, + kInvalid +}; + +/// Converts a SaveWorkspace enumerant to a string +char const *to_string(SaveWorkspace save_option, bool pretty = false); + +/// Parses a SaveWorkspace enumerant from a string +template <> +SaveWorkspace from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates the type of kernel argument +// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric +// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. +// Its c++ equivalent as "type name = initializer" is "u32 m = 32" +// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. +// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" +enum class ArgumentTypeID { + kScalar, + kInteger, + kTensor, + kBatchedTensor, + kStructure, + kEnumerated, + kInvalid +}; + +/// Converts a ArgumentTypeID enumerant to a string +char const *to_string(ArgumentTypeID type, bool pretty = false); + +/// Parses a ArgumentTypeID enumerant from a string +template <> +ArgumentTypeID from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Profiler typedefs +using ProviderVector = std::vector; +using DispositionMap = std::map; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Print vector for the report +template +std::ostream& operator<< (std::ostream& out, const std::vector& v) { + for (size_t i = 0; i < v.size(); ++i) { + out << to_string(v[i], true) << (i + 1u != v.size() ? "," : ""); + } + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..faf317152473cac6dc62ecf8970cd1acfb2c1622 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -0,0 +1,333 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Gemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + + bool enable_sm90_mixed_dtype_shuffle_test{false}; + + // + // Methods + // + + /// Parses the problem + Status parse( + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::GemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// For mixed input dtype kernels + DeviceAllocation *Scale{nullptr}; // Scale tensor + DeviceAllocation *Zero{nullptr}; // Zero tensor + DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + std::vector gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + GemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~GemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..154045295d6443d930ba53387366f4b8abe408a4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct GpuTimer { + + cudaEvent_t events[2]; + + // + // Methods + // + + GpuTimer(); + + GpuTimer(GpuTimer const&) = delete; + + GpuTimer(GpuTimer &&gpu_timer) noexcept; + + ~GpuTimer(); + + /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags + void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags + void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags + void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Returns the duration in milliseconds + double duration(int iterations = 1) const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..62d47990584cbb984935a00a267cff15dbb4f4e5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief GroupedGemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "device_context.h" +#include "operation_profiler.h" +#include "options.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GroupedGemmOperationProfiler : public OperationProfiler { +public: + /// Problem structure obtained from problem space + struct GroupedGemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped}; + + std::vector problem_sizes; + std::vector> problem_sizes_3x; + + /// For exploration purposes + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + std::vector lda{0}; + std::vector ldb{0}; + std::vector ldc{0}; + + std::vector alpha; + std::vector beta; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + bool use_pdl{false}; + + /// Parses the problem + Status parse( + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); }; + int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); }; + int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); }; + + /// Total number of bytes loaded + int64_t bytes(library::GroupedGemmDescription const& operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GroupedGemmDescription const& operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult& result, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + }; + + struct BlockScalingWorkspace { + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector SFA_ptr_array_device; + std::vector SFB_ptr_array_device; + std::vector SFC_ptr_array_device; + std::vector SFD_ptr_array_device; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector SFA_ptr_array_host; + std::vector SFB_ptr_array_host; + std::vector SFC_ptr_array_host; + std::vector SFD_ptr_array_host; + std::vector SFD_reference_ptr_array_host; + + // matrix wide constant, not per-batch or per-group + DeviceAllocation* norm_constant; + }; + + // workspace contains the allocated blocks, arguments just contain the raw + // pointers + struct GroupedGemmWorkspace { + + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector A_ptr_array_device; + std::vector B_ptr_array_device; + std::vector C_ptr_array_device; + std::vector D_ptr_array_device; + std::vector reference_ptr_array_host; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector A_ptr_array_host; + std::vector B_ptr_array_host; + std::vector C_ptr_array_host; + std::vector D_ptr_array_host; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + /// *NOT* the number of groups in the grouped GEMM (we use `num_groups` in the profiler) + int problem_count{1}; + + DeviceAllocation* problem_sizes_array_device{nullptr}; + DeviceAllocation* problem_sizes_3x_array_device{nullptr}; + DeviceAllocation* lda_array_device{nullptr}; + DeviceAllocation* ldb_array_device{nullptr}; + DeviceAllocation* ldc_array_device{nullptr}; + DeviceAllocation* ldd_array_device{nullptr}; + + std::optional block_scales; + + library::GemmGroupedConfiguration configuration; + library::GroupedGemmBlockScaledArguments arguments; + + std::vector host_workspace; + DeviceAllocation device_workspace; + + cudaStream_t stream; + }; + +private: + void init_arguments(Options const& options) { + auto& arguments = gemm_workspace_.arguments; + // these get updated in each profiler run to ensure L2 cycling + arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data(); + arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data(); + arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data(); + arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data(); + + arguments.alpha = problem_.alpha.data(); + arguments.beta = problem_.beta.data(); + arguments.pointer_mode = library::ScalarPointerMode::kHost; + arguments.lda = static_cast(gemm_workspace_.lda_array_device->data()); + arguments.ldb = static_cast(gemm_workspace_.ldb_array_device->data()); + arguments.ldc = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.ldd = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.problem_sizes = + static_cast(gemm_workspace_.problem_sizes_array_device->data()); + arguments.problem_sizes_3x = static_cast*>( + gemm_workspace_.problem_sizes_3x_array_device->data()); + gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); + gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size(); + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + arguments.sm_count = options.device.get_sm_count(0); + if (is_block_scaled) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data(); + arguments.norm_constant = block_scaled_ws.norm_constant->data(); + } + else if (is_blockwise) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + } + } + +protected: + /// GEMM problem obtained from problem space + GroupedGemmProblem problem_; + + /// Device memory allocations + GroupedGemmWorkspace gemm_workspace_; + + bool is_block_scaled{false}; + bool is_blockwise{false}; + +public: + GroupedGemmOperationProfiler(Options const& options); + + virtual ~GroupedGemmOperationProfiler(); + + GroupedGemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream& out) const; + + /// Prints examples + virtual void print_examples(std::ostream& out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Measures performance results + virtual bool profile( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +protected: + /// Initializes the performance result + void initialize_result_( + PerformanceResult& result, + Options const& options, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GroupedGemmWorkspace &gemm_workspace, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration for exploration parameters + void update_workspace_and_result_( + GroupedGemmWorkspace &gemm_workspace, + PerformanceResult &result, + ProblemSpace const &problem_space, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult& result, + Options const& options, + library::Operation const* operation, + void* arguments, + void* host_workspace, + void* device_workspace) override; + + /// Method to profile a CUTLASS Operation for the best configuration for a fixed shape + bool profile_cutlass_for_fixed_shape_( + Options const& options, + library::Operation const* operation, + ProblemSpace const& problem_space); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..446ef2c16739b28aaf038ca62bad6e3cdf667813 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -0,0 +1,287 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include +#include +#include + +// CUTLASS includes +#include "cutlass/trace.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "performance_result.h" +#include "performance_report.h" +#include "problem_space.h" +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class OperationProfiler { +public: + + +protected: + // + // Data members + // + + /// Top-level operation kind + library::OperationKind kind_; + + /// Human readable description + std::string description_; + + /// Arguments parsed from command line + ArgumentDescriptionVector arguments_; + + /// List of providers used to verify and compare each result + ProviderVector verification_providers_; + + /// Model performance result initialized by the operation profiler with workload statistics + /// and reasonable default state. + PerformanceResult model_result_; + + /// Performance result vector constructed by profiling the operation + PerformanceResultVector results_; + +public: + + // + // Methods + // + + /// Ctor + OperationProfiler(); + + OperationProfiler( + Options const &options, + library::OperationKind kind, + ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), + ProviderVector const & verification_providers = ProviderVector()); + + /// Destructor + virtual ~OperationProfiler(); + + /// Obtains the operation kind + library::OperationKind kind() const { return kind_; } + + /// Gets the schema description + std::string const &description() const; + + /// Returns a reference to the arguments + ArgumentDescriptionVector const &arguments() const { return arguments_; } + +public: + + // + // Basic overrides + // + + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const =0; + + /// Entry point to profile all operations in the manifest + virtual int profile_all( + Options const &options, + library::Manifest const &manifest, + DeviceContext &device_context); + +public: + + // + // Operation-specific phases of verification and profiling + // + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + +public: + + // + // Static helpers + // + + /// Sleep for a given duration in ms + static void sleep(int sleep_duration); + + /// Returns true if the current operation description satisfies the problem space + static bool satisfies( + library::OperationDescription const &op_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Compares tensors for equality + static Disposition compare_tensors( + Options const &options, + DeviceAllocation &experimental, + DeviceAllocation &reference, + int64_t count = 0); + + static void save_workspace( + DeviceContext &device_context, + Options const &options, + library::OperationDescription const &desc, + library::Provider provider, + library::Provider verification_provider = library::Provider::kInvalid); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + std::string const &value); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + int64_t value); + +protected: + + /// Sets operation description + static void initialize_result_( + PerformanceResult &result, + library::OperationDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Profiles the GPU kernel launched in `func` running simultaneously on all + /// requested devices. + Status profile_kernel_w_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_no_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + +private: + /// finds string matches filter_string in operation_name + bool find_string_matches_( + std::string const &filter_string, + std::string const &operation_name); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Vector of owning operation profilers +using OperationProfilerVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h new file mode 100644 index 0000000000000000000000000000000000000000..1a957b36eea35f7c0a5366645c3a62298ca56dea --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Command line options for performance test program +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/library/library.h" + +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Global options +class Options { +public: + + /// Cublas and cuDNN options + struct Library { + + // + // Data members + // + + /// Algorithm mode + AlgorithmMode algorithm_mode; + + /// Algorithm enumerants + std::vector algorithms; + + // + // Methods + // + + explicit Library(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to the selected device + struct Device { + + /// Device ID + std::vector devices; + + /// Number of total devices + /// This is not set by the user, it is set by automatically + int num_devices; + + /// CUDA Device properties + std::vector properties; + + /// Total memory allocation on each device + size_t maximum_capacity; + + private: + /// SM Count + /// Limits the number of SMs to use on each device + int sm_count; + + // + // Methods + // + public: + explicit Device(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + void print_device_info(std::ostream &out) const; + + /// Returns the device ID from a device index + int device_id(size_t device_index) const; + + /// Returns the sm_count if set, otherwise returns the number of SMs on the device + int get_sm_count(int device_index) const; + + /// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.) + int compute_capability(int device_index) const; + }; + + /// Options related to initializing input tensors + struct Initialization { + + /// If true, data is initialized randomly. If false, no initialization is performed after + /// allocating tensors. + bool enabled; + + /// If true, data distribution is set by the user and is not allowed to change + /// If false, data distribution is allowed to change based on element_type (library::NumericTypeID) + bool fix_data_distribution; + + /// Data distribution for input tensors + Distribution data_distribution; + + /// Source of random tensor elements + library::Provider provider; + + /// Random number generator seed. + int seed; + + // + // Methods + // + + explicit Initialization(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Helper to parse a Distribution object from the command line parser + static void get_distribution( + cutlass::CommandLine const &args, + std::string const &arg, + cutlass::Distribution &dist); + }; + + /// Options related to verification of the result + struct Verification { + + // + // Data members + // + + /// If true, kernels are verified before they are profiled + bool enabled; + + /// If true, causes profiler to return an error code if no reference check is run. + /// Only valid when verification is enabled. + bool required; + + /// Relative error threshold - zero to require bit-level consistency + double epsilon; + + /// Values smaller than this are assumed to be zero + double nonzero_floor; + + /// List of providers used to verify each result + ProviderVector providers; + + /// Indicates when to save the workspace + SaveWorkspace save_workspace; + + // + // Methods + // + + explicit Verification(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to profiling + struct Profiling { + + /// Number of workspaces to rotate through to avoid cache-resident working sets + int workspace_count{0}; + + /// Number of iterations to warmup each kernel prior to profiling + int warmup_iterations{10}; + + /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration + /// This will always override profiling-duration and min-iterations. + int iterations{100}; + + /// Time to spend profiling each kernel (ms) + int duration{10}; + + /// Minimum number of iterations to profile + int min_iterations{10}; + + /// If true, profiling with cuda graph enabled. + bool use_cuda_graphs{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// within the subset of kernels matching a kernel filter regex. The best + /// performance is determined by screening over a set of predefined M/N/K + /// sizes and performance-related parameters, including cluster shapes, + /// swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_kernel_performance_search{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// for a given M/N/K problem size by evaluating various performance-related + /// parameters such as cluster shapes, swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_best_kernel_for_fixed_shape{false}; + + /// Number of ms to sleep between profiling periods (ms) + int sleep_duration{50}; + + /// If true, profiling is actually conducted. + bool enabled{true}; + + /// If true, profiling returns an error code if no kernels are found to match the filters. + bool error_on_no_match{false}; + + /// If true, profiling returns an error code if no kernel are profiled + // Sometimes the kernel matches but failed to profile (e.g. can_implement() error) + bool error_if_nothing_is_profiled{false}; + + /// List of providers of each functionality to be profiled + ProviderVector providers; + + // + // Methods + // + + explicit Profiling(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to reporting + struct Report { + + /// If true, result is appended to possibly existing file + bool append; + + /// Path to a file containing results + std::string output_path; + + /// Path to a file containing junit xml results + std::string junit_output_path; + + /// Sequence of tags to attach to each result + std::vector> pivot_tags; + + /// If true, reports status of all kernels including those that were + /// not run for the given arguments + bool report_not_run; + + /// Prints human-readable text to stdout. If false, nothing is written to stdout + bool verbose; + + /// Sort results by flops-per-byte + bool sort_flops_per_byte; + + /// Sort results by flops-per-second + bool sort_flops_per_sec; + + /// Prints the name of the kernel being profiled before running the kernel. + /// This is useful for determining which kernel is causing a run of the profiler to hang + bool print_kernel_before_running; + + // + // Methods + // + + explicit Report(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to printing usage and version information + struct About { + + /// If true, usage is printed and the program ends. + bool help; + + /// Prints version string + bool version; + + /// Print information about devices + bool device_info; + + // + // Methods + // + + explicit About(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + static void print_version(std::ostream &out); + }; + +public: + + // + // Data members + // + + /// Top-level execution mode + ExecutionMode execution_mode; + + /// Name of math function to profile + library::OperationKind operation_kind; + + /// Vector of operation name substrings + std::vector operation_names; + + /// Map of problems to run for each operation + /// [operation_name] -> vector of problems, each problem specified as a vector of [argument name] -> [argument value] + std::unordered_map> operation_problems; + + /// Vector of operation name substrings + std::vector excluded_operation_names; + + + // + // Detailed configuration options + // + + /// Configuration + CommandLine cmdline; + Device device; + Initialization initialization; + Library library; + Verification verification; + Profiling profiling; + Report report; + About about; + +public: + + explicit Options(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out) const; + + static std::string indent_str(int indent); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h new file mode 100644 index 0000000000000000000000000000000000000000..07102c99bc0f38a071e1ab828aab30678a3e2d44 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h @@ -0,0 +1,128 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Class performing output during profiling +*/ + +#pragma once + +#include +#include + +// CUTLASS Profiler includes +#include "options.h" +#include "enumerated_types.h" +#include "performance_result.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +class PerformanceReport { +private: + + /// Reference to options + Options const &options_; + + /// Operation kind + library::OperationKind op_kind_; + + /// Operation file name containing performance report of op_kind + std::string op_file_name_; + + /// Output file containing results + std::ofstream output_file_; + + /// Operation file name containing junit performance report of op_kind + std::string op_junit_file_name_; + + /// Output file containing junit results + std::ofstream junit_output_file_; + + /// Flag indicating the performance report is valid + bool good_; + + /// Vector of argument names + std::vector argument_names_; + + /// Counter uniquely identifying problem within the report + size_t problem_index_; + + /// Collection of all results + PerformanceResultVector concatenated_results_; + +public: + + PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); + ~PerformanceReport(); + + bool good() const { return good_; } + + void next_problem(); + void append_result(PerformanceResult result); + void sort_flops_per_byte(PerformanceResultVector &results); + void sort_flops_per_sec(PerformanceResultVector &results); + void append_results(PerformanceResultVector const &results); + +public: + + /// Prints the CSV header + std::ostream & print_csv_header_(std::ostream &out); + + /// Prints the CSV + std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); + + /// @defgroup jUnit Result Generation + /// Functions related to generation of the jUnit results + /// @{ + + std::ostream & print_junit_header_(std::ostream &out); + std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); + std::ostream & print_junit_footer_(std::ostream &out); + + /// @} + + /// Prints the result in human readable form + std::ostream & print_result_pretty_( + std::ostream &out, + PerformanceResult const &result, + bool use_shell_coloring = true); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h new file mode 100644 index 0000000000000000000000000000000000000000..986ac89bc86a267ce8fb181a986f28f3f0936566 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +// CUTLASS Profiler includes +#include "enumerated_types.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance result object +struct PerformanceResult { + + /// Index of problem + size_t problem_index; + + /// library::Provider + library::Provider provider; + + /// Operation kind + library::OperationKind op_kind; + + /// CUTLASS status result from kernels (success or failure) + // Status does information on verification + Status status; + + /// Outcome of verification (worst case verification result) + Disposition disposition; + + /// Outcome of verification (all verification results) + DispositionMap verification_map; + + /// Operation name + std::string operation_name; + + /// Stringified vector of argument values + std::vector > arguments; + + /// Number of bytes read or written + int64_t bytes; + + /// Number of DL flops performed by the math function + int64_t flops; + + /// Average runtime in ms + double runtime; + + /// Average runtime in ms per device + std::vector runtime_vector; + + // + // Members + // + + /// Ctor + PerformanceResult(): + problem_index(0), + op_kind(library::OperationKind::kInvalid), + provider(library::Provider::kInvalid), + disposition(Disposition::kNotRun), + status(Status::kInvalid), + bytes(0), + flops(0), + runtime(0) + { } + + // Copy constructor for deep copy + PerformanceResult(const PerformanceResult& other) = default; + + // Explicitly define copy assignment operator + PerformanceResult& operator=(const PerformanceResult& other) = default; + + /// Returns true if the runtime is valid + bool good() const { + return runtime > 0; + } + + /// Math throughput in units of GFLOP/s + double gflops_per_sec() const { + return double(flops) / runtime / 1.0e6; + } + + /// memory bandwidth in units of GiB/s + double gbytes_per_sec() const { + return double(bytes) / double(1 << 30) / runtime * 1000.0; + } + +}; + +using PerformanceResultVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h new file mode 100644 index 0000000000000000000000000000000000000000..9bdbec657c10cff0dafebd2cb6cd52057f3695c9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h @@ -0,0 +1,1039 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + + "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, + bug-ridden, slow implementation of half of Common Lisp." + + - Greenspun's Tenth Rule of Programming + + + cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian + product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. + + These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, + verify and profile various operations when they are compatible with the command line, and + construct data tables of results that are convenient inputs to post processing in Excel or Pandas. + + By executing multiple problems per invocation, startup overheads may be amortized across many + kernel launches. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// CUTLASS Utility includes +#include "cutlass/util/command_line.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines the argument schema +struct ArgumentDescription { + + /// Type of argument + ArgumentTypeID type; + + /// Prioritized array of aliases used in command line parsing + std::vector aliases; + + /// Description of argument + std::string description; + + // + // Methods + // + + /// Default ctor + ArgumentDescription(): + type(ArgumentTypeID::kInvalid) { } + + /// Constructor with aliases + ArgumentDescription( + ArgumentTypeID type_, + std::vector const &aliases_, + std::string const &description_ + ): + type(type_), aliases(aliases_), description(description_) { } +}; + +/// Vector of arguments +using ArgumentDescriptionVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for kernel arguments +struct KernelArgument { + + // + // Type definitions + // + + /// Value base class + struct Value { + + KernelArgument const *argument; + bool not_null; + + // + // Methods + // + + Value( + KernelArgument const *argument_ = nullptr, + bool not_null_ = true + ): argument(argument_), not_null(not_null_) { } + + virtual ~Value() { } + + virtual std::ostream &print(std::ostream &out) const =0; + }; + + /// Abstract base class to iterate over values within arguments + struct ValueIterator { + + /// Indicates type of kernel argument + KernelArgument const *argument; + + /// If the iterator points to an argument that is null, it needs to be distinguished + /// from end. + bool null_argument; + + // + // Methods + // + + /// Constructs a value iterator - no methods are valid if argument_ == nullptr + ValueIterator( + KernelArgument const *argument_ = nullptr, + bool null_argument_ = false): + argument(argument_), null_argument(null_argument_) { + + if (!argument_->not_null()) { + null_argument = true; + } + } + + virtual ~ValueIterator() { } + + /// Advances to next point in range + virtual void operator++() = 0; + + /// Compares against another value iterator - must be of the same KernelArgument type + virtual bool operator==(ValueIterator const &it) const = 0; + + /// Returns a unique_ptr object pointing to a newly created value object + virtual std::unique_ptr at() const = 0; + + /// Gets the type of the iterator + ArgumentTypeID type() const { + return argument->description->type; + } + + /// Helper to compute inequality + bool operator!=(ValueIterator const &it) const { + return !(*this == it); + } + + std::ostream &print(std::ostream &out) const; + }; + + // + // Data members + // + + /// Describes the argument + ArgumentDescription const *description; + + /// Parent node + KernelArgument *parent; + + /// Sequence in which the kernel argument is to be iterated over. + /// Smaller means faster changing. -1 is don't care + int ordinal; + + // + // Methods + // + + /// Default ctor + KernelArgument( + ArgumentDescription const *description_ = nullptr, + KernelArgument *parent_ = nullptr, + int ordinal_ = -1 + ): description(description_), parent(parent_), ordinal(ordinal_) { } + + virtual ~KernelArgument(); + + /// Returns true if the kernel argument iself is empty + virtual bool not_null() const =0; + + /// Returns a string name for debugging + std::string qualified_name() const { + if (description) { + if (description->aliases.empty()) { + return ""; + } + return description->aliases.front(); + } + return ""; + } + + virtual std::unique_ptr begin() const =0; + virtual std::unique_ptr end() const =0; +}; + +using KernelArgumentVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel +/// type. +struct ScalarArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct ScalarValue : public KernelArgument::Value { + + std::string value; + + // + // Methods + // + + ScalarValue( + std::string const &value_ = "", + ScalarArgument const *argument = nullptr, + bool not_null_ = true + ); + + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct ScalarValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit ScalarValueIterator(ScalarArgument const *argument = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit ScalarArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Closed range supporting additive increment +struct Range { + + // + // Type definitions + // + + enum class Mode { + kSequence, + kRandom, + kRandomLog2, + kInvalid + }; + + struct Iterator { + + int64_t value; + int64_t increment; + Range const *range; + + // + // Methods + // + + Iterator( + int64_t value_ = 0, + int64_t increment_ = 1, + Range const *range_ = nullptr + ): + value(value_), increment(increment_), range(range_) { } + + Iterator & operator++() { + value += increment; + return *this; + } + + Iterator operator++(int) { + Iterator self(*this); + ++(*this); + return self; + } + + bool operator==(Iterator const &it) const { + return value == it.value; + } + + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + static int64_t round(int64_t value, int64_t divisible) { + int64_t rem = (value % divisible); + + // Round either up or down + if (rem > divisible / 2) { + value += (divisible - rem); + } + else { + value -= rem; + } + + return value; + } + + int64_t at() const { + if (!range) { + return value; + } + + switch (range->mode) { + case Mode::kSequence: return value; + + case Mode::kRandom: { + double rnd = double(range->minimum) + + double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); + + int64_t value = int64_t(rnd); + + return round(value, range->divisible); + } + break; + + case Mode::kRandomLog2: { + double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); + double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); + double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); + + int64_t value = int64_t(std::pow(2.0, rnd)); + + return round(value, range->divisible); + } + break; + default: break; + } + return value; + } + + int64_t operator*() const { + return at(); + } + }; + + // + // Data members + // + + int64_t first; ///< first element in range + int64_t last; ///< last element in range + int64_t increment; ///< additive increment between values + + Mode mode; ///< mode selection enables alternative values + int64_t minimum; ///< minimum value to return + int64_t maximum; ///< maximum value to return + int64_t divisible; ///< rounds value down to an integer multiple of this value + + // + // Methods + // + + /// Default constructor - range acts as a scalar + Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } + + /// Range acts as a range + Range( + int64_t first_, + int64_t last_, + int64_t increment_ = 1, + Mode mode_ = Mode::kSequence, + int64_t minimum_ = 0, + int64_t maximum_ = 0, + int64_t divisible_ = 1 + ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { + + // Helpers to avoid constructing invalid ranges + if (increment > 0) { + if (last < first) { + std::swap(last, first); + } + } + else if (increment < 0) { + if (first < last) { + std::swap(last, first); + } + } + else if (last != first) { + last = first; + increment = 1; + } + } + + /// Helper to construct a sequence range + static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { + return Range(first_, last_, increment_, Mode::kSequence); + } + + /// Helper to construct a range that is a random distribution + static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); + } + + /// Helper to construct a range that is a random distribution over a log scale + static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); + } + + /// Returns an iterator to the first element within the range + Iterator begin() const { + return Iterator(first, increment, this); + } + + /// Returns an iterator to the first element *after* the range + Iterator end() const { + return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); + } +}; + +/// Integer-valued argument - represented as a list of integer-valued ranges +struct IntegerArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct IntegerValue : public KernelArgument::Value { + + int64_t value; + + // + // Methods + // + + IntegerValue( + int64_t value_ = 0, + IntegerArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Collection of ranges represent the IntegerArgument's state + using RangeCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct IntegerValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + RangeCollection::const_iterator range_it; + Range::Iterator value_it; + + // + // Methods + // + + IntegerValueIterator(); + IntegerValueIterator(IntegerArgument const *argument); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + RangeCollection ranges; + + // + // Methods + // + + /// Default ctor + IntegerArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + bool _not_null = !ranges.empty(); + return _not_null; + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure defining the data type of tensors +struct TensorArgument : public KernelArgument { + + // + // Type definitions + // + + struct TensorDescription { + + /// Data type of elements + library::NumericTypeID element; + + /// Layout definition + library::LayoutTypeID layout; + + /// Computed extent + std::vector extent; + + /// Enables directly specifying stride value used to size tensor + std::vector stride; + + // + // Methods + // + + TensorDescription( + library::NumericTypeID element_ = library::NumericTypeID::kUnknown, + library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, + std::vector extent_ = std::vector(), + std::vector stride_ = std::vector() + ): + element(element_), layout(layout_), extent(extent_), stride(stride_) {} + }; + + using ValueCollection = std::vector; + + /// Value structure + struct TensorValue : public KernelArgument::Value { + + TensorDescription desc; + + // + // Methods + // + + TensorValue( + TensorDescription const &desc_ = TensorDescription(), + TensorArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Abstract base class to iterate over values within arguments + struct TensorValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit TensorValueIterator(TensorArgument const *argument_); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit TensorArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Numeric data type +struct EnumeratedTypeArgument : public KernelArgument { + + // + // Type definitions + // + + struct EnumeratedTypeValue : public KernelArgument::Value { + + /// Data type of element + std::string element; + + // + // Methods + // + + EnumeratedTypeValue( + std::string const &element_ = std::string(), + EnumeratedTypeArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + ValueCollection values; + + // + // Members + // + + /// Default ctor + explicit EnumeratedTypeArgument(ArgumentDescription const *description): + KernelArgument(description) {} + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object storing the space argument values +class ProblemSpace { +public: + + /// Tuple of arguments + using Problem = std::vector>; + + /// Type used to iterator over things + using IteratorVector = std::vector>; + + /// Iterates over points in the design space + class Iterator { + private: + + /// One iterator per argument + IteratorVector iterators; + + public: + + // + // Methods + // + + explicit Iterator(); + Iterator(ProblemSpace const &problem_space); + Iterator(Iterator &&it); + + // Rule of three + Iterator(Iterator const &) = delete; + Iterator &operator=(Iterator const &it) = delete; + ~Iterator() = default; + + /// Pre-increment - advances to next point in argument range + void operator++(); + + /// Gets the current argument value + Problem at() const; + + /// Moves iterator to end + void move_to_end(); + + /// Equality operator + bool operator==(Iterator const &it) const; + + /// Inequality operator + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + /// Helper to call at() method + Problem operator*() const { + return at(); + } + + /// Helper to print iterator state + std::ostream & print(std::ostream &out) const; + + private: + + /// Helper for recursively constructing iterators + void construct_(KernelArgument const *argument); + }; + +public: + + // + // Data members + // + + KernelArgumentVector arguments; + + /// Map of argument names to their position within the argument vector + std::unordered_map argument_index_map; + +public: + + // + // Methods + // + + /// Default ctor + ProblemSpace() = default; + + /// Constructs a problem space from a vector of arguments. This vector must outlive + /// the ProblemSpace object, which stores pointers to objects within the + /// ArgumentDescriptionVector. + ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); + + Iterator begin() const; // returns an iterator to the first point in the range + Iterator end() const; // returns an iterator to the first point after the range + + /// Returns the index of an argument by name + size_t argument_index(char const *name) const; + + /// Gets all argument names as an ordered vector + std::vector argument_names() const; + + /// Returns the number of dimensions of the problem space + size_t rank() const { return arguments.size(); } + +private: + + /// Helper for recursively cloning + void clone_( + KernelArgumentVector &kernel_args, + ArgumentDescription const *arg_desc); + + /// Parses command line argument + void parse_( + KernelArgument *arg, + CommandLine const &cmdline); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexically casts an argument to an int if it is defined. Returns true if not null. +bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int64_t &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_bool(bool &bool_value, KernelArgument::Value const *value_ptr); + +bool arg_as_bool(bool &bool_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_NumericTypeID( + library::NumericTypeID &numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID( + library::LayoutTypeID &layout_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID( + library::OpcodeClassID &opcode_class, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ConvModeID( + library::ConvModeID &conv_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID( + library::IteratorAlgorithmID &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype( + library::RuntimeDatatype &runtime_datatype, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RasterOrder( + library::RasterOrder &raster_order, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ProviderID( + library::Provider &provider, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_string( + std::string& arg, + char const* name, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + TensorArgument::TensorValue const *value_ptr); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Returns true if a iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ba47a6832077984c334a5467257a151735b088b3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class Rank2KOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + Rank2KOperationProfiler(Options const &options); + + /// Destructor + virtual ~Rank2KOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..fff190a7570cd5811c6e5de6284bf96e40c404b7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h @@ -0,0 +1,227 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class RankKOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + RankKOperationProfiler(Options const &options); + + /// Destructor + virtual ~RankKOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..0c81ef4637175a6de1f44cedddf319436aaff24d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for reduction operation + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class ReductionOperationProfiler : public OperationProfiler { +public: + + + /// Workspace used + struct ReductionWorkspace { + + /// Conv device allocations + DeviceAllocation *Workspace; + DeviceAllocation *Source; + DeviceAllocation *Destination; + DeviceAllocation *Reference; + + /// Library configuration and arguments + library::ReductionConfiguration configuration; + library::ReductionArguments arguments; + + /// Buffer used for the cutlass operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + ReductionWorkspace(): + Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Reduction problem obtained from problem space + MatrixCoord problem_; + + /// Device memory allocations + ReductionWorkspace conv_workspace_; + + +public: + // + // Methods + // + + /// Ctor + ReductionOperationProfiler(Options const &options); + + /// Destructor + virtual ~ReductionOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..60204d8c9d458ab12020a6492de23174739aa584 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h @@ -0,0 +1,214 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "gemm_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class SparseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SparseGemmProblem { + int64_t m; + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t lde; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + static int const sparse = 2; + // every 128b ElementA uses one elementE + int elements_per_128b; + + // + // Methods + // + + SparseGemmProblem(): + m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SparseGemmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *E; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SparseGemmConfiguration configuration; + library::SparseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SparseGemmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + // GEMM problem + SparseGemmProblem problem_; + + /// Device memory allocations + SparseGemmWorkspace gemm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SparseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SparseGemmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..94ded5e803bf914e5ae8c4ebb867cfe42ef829bc --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class SymmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SymmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldc; + SideMode side_mode; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + SymmProblem(): + m(16), n(16), lda(0), ldb(0), ldc(0), + side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::SymmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::SymmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SymmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SymmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + SymmProblem problem_; + + /// Device memory allocations + SymmWorkspace symm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SymmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SymmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..9f21dafa0ecc869840fdba0a9c4414a89bbf4a7d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class TrmmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct TrmmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldd; + SideMode side_mode; + FillMode fill_mode; + DiagType diag_type; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + TrmmProblem(): + m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct TrmmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *D; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + TrmmWorkspace(): + A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + TrmmProblem problem_; + + /// Device memory allocations + TrmmWorkspace trmm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + TrmmOperationProfiler(Options const &options); + + /// Destructor + virtual ~TrmmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2727c989e645eca8e67a5d8d50391ced803cffa --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +struct GPU_Clock +{ + GPU_Clock() { + cudaEventCreate(&start_); + cudaEventCreate(&stop_); + cudaEventRecord(start_); + } + + ~GPU_Clock() { + cudaEventDestroy(start_); + cudaEventDestroy(stop_); + } + + void start() { + cudaEventRecord(start_); + } + + float milliseconds() { + cudaEventRecord(stop_); + cudaEventSynchronize(stop_); + float time; + cudaEventElapsedTime(&time, start_, stop_); + return time; + } + + float seconds() { + return milliseconds() * float(1e-3); + } + + private: + cudaEvent_t start_, stop_; +}; diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h new file mode 100644 index 0000000000000000000000000000000000000000..c95bd1cbeb56cc566394b155ea7ac24f07c28162 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h @@ -0,0 +1,324 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * Utility for parsing command line arguments + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/****************************************************************************** + * command_line + ******************************************************************************/ + +/** + * Utility for parsing command line arguments + */ +struct CommandLine { + std::vector keys; + std::vector values; + std::vector args; + + /** + * Constructor + */ + CommandLine(int argc, const char** argv) { + using namespace std; + + for (int i = 1; i < argc; i++) { + string arg = argv[i]; + + if ((arg[0] != '-') || (arg[1] != '-')) { + args.push_back(arg); + continue; + } + + string::size_type pos; + string key, val; + if ((pos = arg.find('=')) == string::npos) { + key = string(arg, 2, arg.length() - 2); + val = ""; + } else { + key = string(arg, 2, pos - 2); + val = string(arg, pos + 1, arg.length() - 1); + } + + keys.push_back(key); + values.push_back(val); + } + } + + /** + * Constructor to represent a command line from a map of [argument] -> [value] + */ + CommandLine(std::unordered_map& arg_map) { + for (const auto& [key, value] : arg_map) { + keys.push_back(key); + values.push_back(value); + } + } + + /** + * Checks whether a flag "--" is present in the commandline + */ + bool check_cmd_line_flag(const char* arg_name) const { + using namespace std; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) return true; + } + return false; + } + + /** + * Returns number of naked (non-flag and non-key-value) commandline parameters + */ + size_t num_naked_args() const { + return args.size(); + } + + /** + * Print naked (non-flag and non-key-value) commandline parameters + */ + void print_naked_args(std::ostream &out) const { + for (auto arg : args) { + out << " " << arg <<"\n"; + } + } + + /** + * Returns the commandline parameter for a given index (not including flags) + */ + template + void get_cmd_line_argument(size_t index, value_t& val) const { + using namespace std; + if (index < args.size()) { + istringstream str_stream(args[index]); + str_stream >> val; + } + } + + /** + * Obtains the boolean value specified for a given commandline parameter --= + */ + void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { + val = _default; + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + val = !(value == "0" || value == "false"); + } + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val) const { + + get_cmd_line_argument(arg_name, val, val); + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val, + value_t const& _default) const { + using namespace std; + + val = _default; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) { + istringstream str_stream(values[i]); + str_stream >> val; + } + } + } + + /** + * Returns the values specified for a given commandline parameter --=,* + */ + template + void get_cmd_line_arguments(const char* arg_name, + std::vector& vals, + char sep = ',') const { + using namespace std; + + if (check_cmd_line_flag(arg_name)) { + // Clear any default values + vals.clear(); + + // Recover from multi-value string + for (size_t i = 0; i < keys.size(); ++i) { + if (keys[i] == string(arg_name)) { + string val_string(values[i]); + separate_string(val_string, vals, sep); + } + } + } + } + + /** + * Returns the values specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_pairs(const char* arg_name, + std::vector >& tokens, + char delim = ',', + char sep = ':') const { + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + tokenize(tokens, value, delim, sep); + } + } + + /** + * Returns a list of ranges specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_ranges(const char* arg_name, + std::vector >& vals, + char delim = ',', + char sep = ':') const { + std::vector ranges; + get_cmd_line_arguments(arg_name, ranges, delim); + + for (std::vector::const_iterator range = ranges.begin(); + range != ranges.end(); ++range) { + + std::vector range_vals; + separate_string(*range, range_vals, sep); + vals.push_back(range_vals); + } + } + + /** + * The number of pairs parsed + */ + int parsed_argc() const { return (int)keys.size(); } + + //------------------------------------------------------------------------- + // Utility functions + //------------------------------------------------------------------------- + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector >& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + // Home-built to avoid Boost dependency + size_t s_idx = 0; + size_t d_idx = std::string::npos; + while (s_idx < str.size()) { + d_idx = str.find_first_of(delim, s_idx); + + size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size()); + size_t sep_idx = str.find_first_of(sep, s_idx); + size_t offset = 1; + if (sep_idx == std::string::npos || sep_idx >= end_idx) { + sep_idx = end_idx; + offset = 0; + } + + std::pair item( + str.substr(s_idx, sep_idx - s_idx), + str.substr(sep_idx + offset, end_idx - sep_idx - offset)); + + tokens.push_back(item); + s_idx = end_idx + 1; + } + } + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + typedef std::vector > TokenVector; + typedef TokenVector::const_iterator token_iterator; + + std::vector > token_pairs; + tokenize(token_pairs, str, delim, sep); + for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { + tokens.push_back(tok->first); + } + } + + template + static void separate_string(std::string const& str, + std::vector& vals, + char sep = ',') { + std::istringstream str_stream(str); + std::string::size_type old_pos = 0; + std::string::size_type new_pos = 0; + + // Iterate -delimited values + value_t val; + while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { + if (new_pos != old_pos) { + str_stream.width(new_pos - old_pos); + str_stream >> val; + vals.push_back(val); + } + + // skip over delimiter + str_stream.ignore(1); + old_pos = new_pos + 1; + } + + // Read last value + str_stream >> val; + vals.push_back(val); + } +}; + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8ace1e0a232ea7cccbb2089ec8432783c49410dd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp @@ -0,0 +1,528 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +//-- BLAM_DEBUG_OUT --------------------------------------------------------- +#ifdef BLAM_DEBUG +# include +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl +# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl +# endif // BLAM_DEBUG_OUT +#else +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) +# define BLAM_DEBUG_OUT_2(msg) +# endif // BLAM_DEBUG_OUT +#endif // BLAM_DEBUG + +// User could potentially define ComplexFloat/ComplexDouble instead of std:: +#ifndef BLAM_COMPLEX_TYPES +#define BLAM_COMPLEX_TYPES 1 +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(complex) + +namespace blam { +template +using Complex = cuda::std::complex; +using ComplexFloat = cuda::std::complex; +using ComplexDouble = cuda::std::complex; +} +#endif // BLAM_COMPLEX_TYPES + +// User could potentially define Half instead of cute:: +#ifndef BLAM_HALF_TYPE +#define BLAM_HALF_TYPE 1 +#include +namespace blam { +using Half = cute::half_t; +} +#endif // BLAM_HALF_TYPE + +namespace blam +{ +namespace cublas +{ + +inline const char* +cublas_get_error(cublasStatus_t status) +{ + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; + default: + return "CUBLAS_ERROR -- "; + } +} + +inline bool +cublas_is_error(cublasStatus_t status) +{ + return status != CUBLAS_STATUS_SUCCESS; +} + + +// hgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const Half* beta, + Half* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasHgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + reinterpret_cast(beta), + reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, + CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// mixed hf gemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + beta, + C, CUDA_R_32F, ldC, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// igemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const int32_t* alpha, + const int8_t* A, int ldA, + const int8_t* B, int ldB, + const int32_t* beta, + int32_t* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasIgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + A, CUDA_R_8I, ldA, + B, CUDA_R_8I, ldB, + beta, + C, CUDA_R_32I, ldC, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// sgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, + const float* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasSgemm"); + + return cublasSgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// dgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, + const double* B, int ldB, + const double* beta, + double* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasDgemm"); + + return cublasDgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// cgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, + const ComplexFloat* B, int ldB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasCgemm"); + + return cublasCgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// zgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, + const ComplexDouble* B, int ldB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasZgemm"); + + return cublasZgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, int loA, + const Half* B, int ldB, int loB, + const Half* beta, + Half* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); + + return cublasHgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast<__half*>(C), ldC, loC, + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, int loA, + const float* B, int ldB, int loB, + const float* beta, + float* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); + + return cublasSgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, int loA, + const double* B, int ldB, int loB, + const double* beta, + double* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); + + return cublasDgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, int loA, + const ComplexFloat* B, int ldB, int loB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); + + return cublasCgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, int loA, + const ComplexDouble* B, int ldB, int loB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); + + return cublasZgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* const A[], int ldA, + const Half* const B[], int ldB, + const Half* beta, + Half* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmBatched"); + + return cublasHgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), ldA, + // A, ldA, // cuBLAS 9.2 + reinterpret_cast(const_cast(B)), ldB, + // B, ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + reinterpret_cast<__half**>(const_cast(C)), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* const A[], int ldA, + const float* const B[], int ldB, + const float* beta, + float* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmBatched"); + + return cublasSgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* const A[], int ldA, + const double* const B[], int ldB, + const double* beta, + double* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmBatched"); + + return cublasDgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* const A[], int ldA, + const ComplexFloat* const B[], int ldB, + const ComplexFloat* beta, + ComplexFloat* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmBatched"); + + return cublasCgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* const A[], int ldA, + const ComplexDouble* const B[], int ldB, + const ComplexDouble* beta, + ComplexDouble* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmBatched"); + + return cublasZgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +} // end namespace cublas +} // end namespace blam diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..88481a82e0e08f06b54c07c946d28160d41f9f07 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Contains code for debugging cutlass code +*/ + +#pragma once + +#include "device_dump.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUDA_LOG) +#if !defined(__CUDA_ARCH__) +#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) +#else +#define CUDA_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, \ + blockIdx.y, \ + blockIdx.z, \ + threadIdx.x, \ + threadIdx.y, \ + threadIdx.z, \ + __VA_ARGS__); +#endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUDA_LOG_DEBUG) +#ifdef DEBUG +#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) +#else +#define CUDA_LOG_DEBUG(format, ...) +#endif +#endif + +/** + * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) + * along with the supplied source context. + * + * \return The CUDA error. + */ +__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, + const char* expression, + const char* filename, + int line) { + (void)filename; + (void)line; + if (error) { +#if !defined(__CUDA_ARCH__) + fprintf( + stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); + fflush(stderr); +#else + printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); +#endif + } + return error; +} + +/** + * \brief Perror macro + */ +#ifndef CUDA_PERROR +#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) +#endif + +/** + * \brief Perror macro with exit + */ +#ifndef CUDA_PERROR_EXIT +#define CUDA_PERROR_EXIT(e) \ + do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ + exit(1); \ + } } while (0) +#endif + +/** + * \brief Perror macro only if DEBUG is defined + */ +#ifndef CUDA_PERROR_DEBUG +#ifdef DEBUG +#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) +#else +#define CUDA_PERROR_DEBUG(e) (e) +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A small helper class to dump a type at compile time +// Usage:: DumpType::Class +template +struct DebugType {}; + +template +void DebugTypeFunc(T const& t) { + T::t; +} + +// A small helper class to dump a compile time constant at compile time +// Usage: DumpValue::kConstant +template +struct DebugValue {}; diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h new file mode 100644 index 0000000000000000000000000000000000000000..a73a8cfe79dd22c2d298fcb3be8cf25d5e3f5734 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +/** + * \file + * \brief C++ interface to dump fragments and shared memory contents for + * debugging. + */ + +namespace cutlass { +namespace debug { + +/****************************************************************************** + * Dump the fragments + ******************************************************************************/ + +/// The first N threads dump the first M elements from their fragments with a +/// stride of S elements. If N is not specified, dump the data of all the +/// threads. If M is not specified, dump all the elements of the fragment. +template +CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, + int S = 1) { + int total_threads = blockDim.x * blockDim.y * blockDim.z; + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (N < 0 || N > total_threads) { + if (thread_id == 0 && block_id == 0) + printf("Thread number N = %d should between [1, %d].\n", N, + total_threads); + + __syncthreads(); + + return; + } + + int total_elements = int(frag.size()); + + if (M < 0 || M > total_elements) { + if (thread_id == 0 && block_id == 0) + printf("Element number M = %d should between [1, %d].\n", M, + total_elements); + + __syncthreads(); + + return; + } + + if (N == 0) N = total_threads; + + if (M == 0) M = total_elements; + + if (S < 1 || S > M) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, M); + + __syncthreads(); + + return; + } + + if (thread_id == 0 && block_id == 0) + printf("\n*******************Dumping the fragments*******************\n\n"); + + CUTLASS_PRAGMA_NO_UNROLL + for (int tid = 0; tid < N; ++tid) { + if (tid == thread_id) { + printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < M; i += S) { + printf("%.0f ", float(typename Fragment::value_type(frag[i]))); + } + printf("\n"); + } + + __syncthreads(); + } + + if (thread_id == 0 && block_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} + +/****************************************************************************** + * Dump the shared memory + ******************************************************************************/ + +#define SHMEM_ROW_SIZE 128 + +/// Dump the shared memory contents. ptr is the begin address, size specifies +/// the number of elements that need to be dumped, and S specifies the stride. +template +CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (ptr == nullptr) { + if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); + + __syncthreads(); + return; + } + + if (size < 1) { + if (thread_id == 0 && block_id == 0) + printf("Element size is less than 1\n"); + + __syncthreads(); + + return; + } + + int row_elements = SHMEM_ROW_SIZE / sizeof(Element); + + if (S < 1 || S > row_elements) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, row_elements); + + __syncthreads(); + + return; + } + + __syncthreads(); + + if (thread_id == 0) + printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); + + if (thread_id == 0) { + for (int i = 0; i < size; i += row_elements) { + for (int j = 0; j < row_elements; j += S) { + printf("%.0f ", float(ptr[i + j])); + } + + printf("\n"); + } + } + + if (thread_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} +} // namespace debug +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..59457b2e8122f46e443844fe276b2c7fb35f3f56 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h @@ -0,0 +1,402 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do group norm on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +extern __shared__ char groupnorm_shm[]; + +// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, +// we store the input in the shared memory. +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_store_locally(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; + float local_sum[1] = {0.0f}; + +// load from global memory into shared memory +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + local_val[local_val_offset + j] = tmp_vec_ptr[j]; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(local_val[local_val_offset + j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + const int local_val_offset = i * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance + * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = tmp_vec; + } + } +} + +// For large prod_dim1_to_last_dim/num_groups, +// in which the data cannot be stored locally, +// we will load from global memory multiple times, +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_multiple_load(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + float local_sum[1] = {0.0f}; + +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + TVec output_tmp_vec; + T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = + (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + output_tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = output_tmp_vec; + } + } +} + +//ref_input & ref_output should be [N, H, W, C] +//ref_gamma & ref_beta should be [1, 1, 1, C] +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int N = input_size.n(); + const int H = input_size.h(); + const int W = input_size.w(); + const int C = input_size.c(); + if (C % num_groups != 0){ + printf("[ERROR] C should be a multiple of num_groups.\n"); + } + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + + const int dim0 = N; + const int last_dim = C; + const int prod_dim1_to_last_dim = H*W*C; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int s_group_stride = last_dim / num_groups; + dim3 grid(num_groups, dim0); + int threadblock_size = 32; + if (s_group_stride % 2 == 0) { + const int T_PER_TVec = 2; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; + // the size of grid & block may have better choice for different cases. + // ensure shared memory is smaller than 48KB + if (std::is_same::value){ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + else{ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + } + else { + const int T_PER_TVec = 1; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + +} + +} //namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0fcbf5cb0f4bf3152a708c6e3845e89fd214cfac --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h @@ -0,0 +1,644 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. + * \tparam T: data type + */ +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + T local_val[ITEM_PER_THREAD]; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + const T zero = T(0.0f); + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + local_val[i] = index < n ? input[index] : zero; + local_sums[0] += static_cast(local_val[i]); + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n){ + const float tmp = static_cast(local_val[i]) - s_mean; + local_sums[0] += tmp * tmp; + } + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n) { + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T2 local_val[ITEM_PER_THREAD]; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + const T2 zero = {T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_2 ? input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const float2 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, + const T4* input, + const T4* gamma, + const T4* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T4 local_val[ITEM_PER_THREAD]; + const int n_4 = n / 4; + int offset = m_idx * n_4; + input += offset; + output += offset; + + const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_4 ? input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + + static_cast(local_val[i].z) + static_cast(local_val[i].w); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const float4 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean, + static_cast(local_val[i].z) - s_mean, + static_cast(local_val[i].w) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const T4 gamma_val = gamma[index]; + const T4 beta_val = beta[index]; + T4 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); + tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_val = local_val - s_mean; + local_sums[0] += local_val * local_val; + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + const T local_val = input[index]; + output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const float2 tmp = {static_cast(local_val.x) - s_mean, + static_cast(local_val.y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } +} + +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + dim3 grid(m); + dim3 block((n + 31)/32*32); + if (block.x > 1024){ + block.x = 1024; + } + // TODO : There should be better configs for different cases, we only use several samples to show how to use here + // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. + if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { + block.x = (n/4 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (float4*)output, + (const float4*)input, + (const float4*)gamma, + (const float4*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (half4*)output, + (const half4*)input, + (const half4*)gamma, + (const half4*)beta, + m, + n); + } + } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) + else if (n % 2 == 0) { + if (n / 2 <= 1024) { + block.x = (n/2 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } //if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n / 2 <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/ 16 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 32768) + else { + if (block.x > 512) + block.x = 512; + if (std::is_same::value) { + layernorm_twoPassAlgo_e2<<>>( + (float2 *)output, + (const float2 *)input, + (const float2 *)gamma, + (const float2 *)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_e2<<>>( + (half2 *)output, + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + m, + n); + } + } + } // if (n % 2 == 0) + else { + if (n <= 1024) { + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/16 + 32)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 32768) + else{ + if (block.x > 512) { + block.x = 512; + } + layernorm_twoPassAlgo_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } + } +} + +} //namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..44f6a467a5d0938289e4bc127cddc13b9aeabdf3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h @@ -0,0 +1,375 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ interface to CUDA device memory management functions. + */ + +#include +#include + +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_types.h" +#include "cutlass/trace.h" +#include "exceptions.h" + +namespace cutlass { +namespace device_memory { + +/****************************************************************************** + * Allocation lifetime + ******************************************************************************/ + +/// Allocate a buffer of \p count elements of type \p T on the current CUDA device +template +T* allocate(size_t count = 1) { + + T* ptr = 0; + size_t bytes = count * sizeof_bits::value / 8; + + cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + + if (cuda_error != cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 0) + std::ostringstream os; + os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); +#endif + throw cuda_exception("Failed to allocate memory", cuda_error); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + std::ostringstream os; + os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); + } +#endif + + return ptr; +} + +/// Free the buffer pointed to by \p ptr +template +void free(T* ptr) { + if (ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + throw cuda_exception("Failed to free device memory", cuda_error); + } + } +} + +/****************************************************************************** + * Data movement + ******************************************************************************/ + +template +void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { + size_t bytes = count * sizeof_bits::value / 8; + if (bytes == 0 && count > 0) { + bytes = 1; + } + cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); + if (cuda_error != cudaSuccess) { + std::ostringstream os; + os << "cutlass::device_memory::copy: cudaMemcpy() failed: " + << "dst=" << dst << ", src=" << src + << ", bytes=" << bytes << ", count=" << count; + if (kind == cudaMemcpyHostToDevice) { + os << ", kind=cudaMemcpyHostToDevice"; + } + else if (kind == cudaMemcpyDeviceToHost) { + os << ", kind=cudaMemcpyDeviceToHost"; + } + else if (kind == cudaMemcpyDeviceToDevice) { + os << ", kind=cudaMemcpyDeviceToDevice"; + } + else if (kind == cudaMemcpyHostToHost) { + os << ", kind=cudaMemcpyHostToHost"; + } + else if (kind == cudaMemcpyDefault) { + os << ", kind=cudaMemcpyDefault"; + } + else { + os << ", kind=Unknown"; + } + os << ", error: " << cudaGetErrorString(cuda_error); + + throw cuda_exception(os.str().c_str(), cuda_error); + } +} + +template +void copy_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToDevice); +} + +template +void copy_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToHost); +} + +template +void copy_device_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToDevice); +} + +template +void copy_host_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToHost); +} + +/// Copies elements from device memory to host-side range +template +void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { + size_t elements = end - begin; + copy_to_host(&*begin, device_begin, elements); +} + +/// Copies elements to device memory from host-side range +template +void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { + size_t elements = end - begin; + copy_to_device(device_begin, &*begin, elements); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DeviceAllocation { +public: + + /// Delete functor for CUDA device memory + struct deleter { + void operator()(T* ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + // noexcept + // throw cuda_exception("cudaFree() failed", cuda_error); + return; + } + } + }; + +public: + // + // Data members + // + + /// Number of elements of T allocated on the current CUDA device + size_t capacity; + + /// Smart pointer + platform::unique_ptr smart_ptr; + +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + + // + // Methods + // + + /// Constructor: allocates no memory + DeviceAllocation() : capacity(0) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device + DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} + + /// Copy constructor + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + } + + /// Destructor + ~DeviceAllocation() { reset(); } + + /// Returns a pointer to the managed object + T* get() const { return smart_ptr.get(); } + + /// Releases the ownership of the managed object (without deleting) and resets capacity to zero + T* release() { + capacity = 0; + return smart_ptr.release(); + } + + /// Deletes the managed object and resets capacity to zero + void reset() { + capacity = 0; + smart_ptr.reset(); + } + + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + reset(device_memory::allocate(_capacity), _capacity); + } + + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity + void reset(T* _ptr, size_t _capacity) { + smart_ptr.reset(_ptr); + capacity = _capacity; + } + + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. + void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + + /// Returns a pointer to the object owned by *this + T* operator->() const { return smart_ptr.get(); } + + /// Returns the deleter object which would be used for destruction of the managed object. + deleter& get_deleter() { return smart_ptr.get_deleter(); } + + /// Returns the deleter object which would be used for destruction of the managed object (const) + const deleter& get_deleter() const { return smart_ptr.get_deleter(); } + + /// Copies a device-side memory allocation + DeviceAllocation & operator=(DeviceAllocation const &p) { + if (capacity != p.capacity) { + smart_ptr.reset(device_memory::allocate(p.capacity)); + capacity = p.capacity; + } + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + return *this; + } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. + void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h new file mode 100644 index 0000000000000000000000000000000000000000..8e38029951d27c0be8da059b59d2a83fe2762ef1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. + * \tparam T: data type + */ +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + +template +__global__ void nchw_to_nhwc_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + const int hw = h*w; + const int chw = c*hw; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t ci0 = blockIdx.y * 32; + const int32_t hwi0 = blockIdx.x * 32; + + const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; + const T *A = input + input_idx; + if (hwi0 + lid < hw) { + const int lid_x_33 = lid * 33; + if ((ci0 + 32) <= c) { + int ci = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + ci] = A[lid]; + A = &A[8 * hw]; + ci += 8; + } + } else { + for (int ci = wid; ci < 32; ci += 8) { + if ((ci + ci0) < c) { + shbuf[lid_x_33 + ci] = A[lid]; + } + A = &A[8 * hw]; + } + } + } + __syncthreads(); + + const int32_t ciOut = ci0 + lid; + output = &output[ni * chw + ciOut]; + if (ciOut < c) { + if (hwi0 + 32 < hw) { + int hwI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + hwI += 8; + } + } else { + for (int hwI = wid; hwI < 32; hwI += 8) { + if (hwi0 + hwI < hw) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + } + } + } + } +} + +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.h() && + input_tensor_size.h() == output_tensor_size.w() && + input_tensor_size.w() == output_tensor_size.c()); + + int n = output_tensor_size.n(); + int h = output_tensor_size.h(); + int w = output_tensor_size.w(); + int c = output_tensor_size.c(); + + dim3 grid((h*w + 31)/32, (c + 31)/32, n); + dim3 block(32, 8); + nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); +} + +} //namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h new file mode 100644 index 0000000000000000000000000000000000000000..f58da62a35350b4a865f4521ec1cbb76ae87e874 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels for padding in device memory with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface for padding in a device memory tensor with NHWC layout + * \tparam T: data type + */ +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_padding_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const int32_t c_in, + const int32_t c_out, + const T zero, + const T *input, + T *output){ + + const int32_t idx_jump = blockDim.x * gridDim.x; + const int32_t total_elements = n * h * w * c_out; + + int32_t c_idx, w_idx, h_idx, n_idx, resudial; + + T value; + for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { + + c_idx = idx%c_out; + if (c_idx >= c_in){ + value = zero; + } + else{ + resudial = idx/c_out; + w_idx = resudial%w; + resudial = resudial/w; + h_idx = resudial%h; + n_idx = resudial/h; + resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; + value = input[resudial]; + } + output[idx] = value; + } +} + + +// fast kernel for c_in = 3 & c_out = 4 +template +__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 256; + const int lower_bound = max_output_element < output_offset + 256 ? max_output_element : output_offset + 256; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; + Telement array[element_in_Tio]; + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; + output[i] = *((const Tio *)array); + } +} + +// fast kernel for c_in = 3 & c_out = 8 +template +__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 512; + const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; + Telement array[element_in_Tio]; + //float + if (element_in_Tio == 4){ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]); + } + //half + else{ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = (k >= 3) ? zero_element : shm_element[k]; + } + output[i] = *((const Tio *)array); + } +} + +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream){ + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.h() && + input_tensor_size.w() == output_tensor_size.w() && + input_tensor_size.c() <= output_tensor_size.c()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c_in = input_tensor_size.c(); + int c_out = output_tensor_size.c(); + + //case 1 : channel == 3 padding to 4 or 8 + if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ + dim3 block(192); + const int nhw = n*h*w; + const int nhwc = nhw*c_in; + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const int element_in_Tio = 8; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const int4 zero_io = {0, 0, 0, 0}; + const half_t zero_element = static_cast(0.0f); + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + //for float + else{ + const int element_in_Tio = 4; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; + const float zero_element = 0.0f; + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + } + //case 2 : even channel + else if ((c_out % 2) == 0 && (c_in % 2) == 0){ + int32_t total_elements = n * h * w * c_out / 2; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const __half2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); + } + //for float + else{ + const float2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); + } + } + //case 3 : odd channel + else{ + int32_t total_elements = n * h * w * c_out; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + const T zero = static_cast(0.0f); + nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data()); + } +} + + +} //namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..5633456c1412ff41366ec4c6ec5c3e6e3a2d6c19 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -0,0 +1,573 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::MatrixCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream); + +/** get the output size of pooling + */ +inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) +{ + return (H_W + 2 * padding - kernel_size) / stride + 1; +} + +/** + * input is [N, H, W, C] + * assume stride == kernel_size + * output_h = (H + 2*padding_H - kernel_H)/stride_H + * output_w = (W + 2*padding_W - kernel_W)/stride_W + * output is [N, output_h, output_w, C] + * grid(N, output_h, output_w) + * block(min(C, 256)) : + * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) +*/ +template +__global__ void pooling_nhwc_element1_kernel(T* output, + const T* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float pooling; + if (IS_AVG_POOLING){ + pooling = 0.0f; + } + else{ + pooling = -FLT_MAX; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const float tmp = static_cast(input[idx + c_idx]); + if (IS_AVG_POOLING){ + pooling = pooling + tmp; + } + else{ + pooling = pooling > tmp ? pooling : tmp; + } + } + } + + T output_val; + if (IS_AVG_POOLING){ + output_val = T(pooling/kernel_size2); + } + else{ + output_val = T(pooling); + } + output[c_idx] = output_val; + } +} + +template +__global__ void pooling_nhwc_element2_kernel(T2* output, + const T2* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float2 pooling; + if (IS_AVG_POOLING) { + pooling = {0.0f, 0.0f}; + } + else { + pooling = {-FLT_MAX, -FLT_MAX}; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const T2 tmp = input[idx + c_idx]; + const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; + if (IS_AVG_POOLING) { + pooling.x += tmp_flt2.x; + pooling.y += tmp_flt2.y; + } + else { + pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; + pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; + } + } + } + + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling.x/kernel_size2); + output_val.y = T(pooling.y/kernel_size2); + } + else { + output_val.x = T(pooling.x); + output_val.y = T(pooling.y); + } + output[c_idx] = output_val; + } +} + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C, N) + * block(block_size) -- each block deals with H*W/block_size elements; +*/ +template +__global__ void pooling_nxhTo1x1_element1_kernel( + T* output, const T* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[1]; + if (IS_AVG_POOLING) { + pooling[0] = 0.0f; + } + else { + pooling[0] = -FLT_MAX; + } + const size_t input_offset = n_idx * HW * C + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + float val = static_cast(input[index * C]); + if (IS_AVG_POOLING) { + pooling[0] += val; + } + else { + pooling[0] = pooling[0] > val ? pooling[0] : val; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T output_val; + if (IS_AVG_POOLING) { + output_val = T(pooling[0] / HW); + } + else { + output_val = T(pooling[0]); + } + output[0] = output_val; + } +} + + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C/2, N) + * block(block_size) -- each thread deals with H*W/block_size * 2 elements; +*/ +template +__global__ void pooling_nxhTo1x1_element2_kernel( + T2* output, const T2* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[2]; + if (IS_AVG_POOLING) { + pooling[0] = pooling[1] = 0.0f; + } + else { + pooling[0] = pooling[1] = -FLT_MAX; + } + const int C_2 = C / 2; + const size_t input_offset = n_idx * HW * C_2 + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C_2 + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + T2 val = input[index * C_2]; + float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; + if (IS_AVG_POOLING) { + pooling[0] += val_flt2.x; + pooling[1] += val_flt2.y; + } + else { + pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; + pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling[0] / HW); + output_val.y = T(pooling[1] / HW); + } + else { + output_val.x = T(pooling[0]); + output_val.y = T(pooling[1]); + } + output[0] = output_val; + } +} + +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::Tensor4DCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream) { + + assert(input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.c()); + + const int N = input_tensor_size.n(); + const int H = input_tensor_size.h(); + const int W = input_tensor_size.w(); + const int C = input_tensor_size.c(); + const int padding_H = padding.h(); + const int padding_W = padding.w(); + const int kernel_H = filter_tensor_size.h(); + const int kernel_W = filter_tensor_size.w(); + const int stride_H = stride.row(); + const int stride_W = stride.column(); + + const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); + const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); + + assert(output_tensor_size.h() == output_H && + output_tensor_size.w() == output_W); + + if (C % 2 != 0) { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } // if (poolingType == 0) + else { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C < block.x) { + block.x = C; + } + if (poolingType == 0) { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (poolingType == 0) + else { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } // if (C % 2 != 0)) + else { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C/2, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C/2 < block.x) { + block.x = C/2; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } + } +} + +} //namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h new file mode 100644 index 0000000000000000000000000000000000000000..babfecd39205ebff39794133868e4a95b7e9525c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. + * \tparam T: data type + */ +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_to_nchw_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + + const int hw = h*w; + const int hwc = hw*c; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t hwi0 = blockIdx.y * 32; + const int32_t ci0 = blockIdx.x * 32; + + const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; + const T *A = input + input_idx; + if (ci0 + lid < c) { + const int lid_x_33 = lid * 33; + if ((hwi0 + 32) <= hw) { + int hwi = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + hwi] = A[lid]; + A = &A[8 * c]; + hwi += 8; + } + } else { + for (int hwi = wid; hwi < 32; hwi += 8) { + if ((hwi + hwi0) < hw) { + shbuf[lid_x_33 + hwi] = A[lid]; + } + A = &A[8 * c]; + } + } + } + __syncthreads(); + + const int32_t hwiOut = hwi0 + lid; + output = &output[ni * hwc + hwiOut]; + if (hwiOut < hw) { + if (ci0 + 32 < c) { + int cI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + cI += 8; + } + } else { + for (int cI = wid; cI < 32; cI += 8) { + if (ci0 + cI < c) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + } + } + } + } +} + +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.c() && + input_tensor_size.w() == output_tensor_size.h() && + input_tensor_size.c() == output_tensor_size.w()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c = input_tensor_size.c(); + + dim3 grid((c + 31)/32, (h*w + 31)/32, n); + dim3 block(32, 8); + nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); + +} + +} //namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0d1b1af56e4463640edc3e9c82533baf815c9b27 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h @@ -0,0 +1,186 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/device_utils.h" +#include + +namespace cutlass { + +__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input, + const float4 *weight, + const int m, const int n, float epsilon) { + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + const int n_8 = n / 8; + int offset = m_idx * n_8; + input += offset; + output += offset; + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const half2 *h1 = (half2 *)&local_val.x; + const half2 *h2 = (half2 *)&local_val.y; + const half2 *h3 = (half2 *)&local_val.z; + const half2 *h4 = (half2 *)&local_val.w; + local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + + static_cast(h1->y) * static_cast(h1->y) + + static_cast(h2->x) * static_cast(h2->x) + + static_cast(h2->y) * static_cast(h2->y) + + static_cast(h3->x) * static_cast(h3->x) + + static_cast(h3->y) * static_cast(h3->y) + + static_cast(h4->x) * static_cast(h4->x) + + static_cast(h4->y) * static_cast(h4->y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const float4 weight_val = weight[index]; + + const half2 *l1 = (half2 *)&local_val.x; + const half2 *l2 = (half2 *)&local_val.y; + const half2 *l3 = (half2 *)&local_val.z; + const half2 *l4 = (half2 *)&local_val.w; + + const half2 *g1 = (half2 *)&weight_val.x; + const half2 *g2 = (half2 *)&weight_val.y; + const half2 *g3 = (half2 *)&weight_val.z; + const half2 *g4 = (half2 *)&weight_val.w; + + float4 tmp; + half2 *h1 = (half2 *)&tmp.x; + half2 *h2 = (half2 *)&tmp.y; + half2 *h3 = (half2 *)&tmp.z; + half2 *h4 = (half2 *)&tmp.w; + + h1->x = half(static_cast(l1->x) * s_mean * static_cast(g1->x)); + h1->y = half(static_cast(l1->y) * s_mean * static_cast(g1->y)); + h2->x = half(static_cast(l2->x) * s_mean * static_cast(g2->x)); + h2->y = half(static_cast(l2->y) * s_mean * static_cast(g2->y)); + h3->x = half(static_cast(l3->x) * s_mean * static_cast(g3->x)); + h3->y = half(static_cast(l3->y) * s_mean * static_cast(g3->y)); + h4->x = half(static_cast(l4->x) * s_mean * static_cast(g4->x)); + h4->y = half(static_cast(l4->y) * s_mean * static_cast(g4->y)); + + output[index] = tmp; + } +} + +template +__global__ void rmsnorm_twoPassAlgo_e1(T* output, + const T* input, + const T* weight, + const int m, const int n, + float epsilon) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val * local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T weight_val = weight[index]; + const T local_val = input[index]; + output[index] = T(static_cast(local_val) * s_mean * static_cast(weight_val)); + } +} + +template +void rmsnorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_weight, + cudaStream_t stream, float epsilon = 1e-5f){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* weight = ref_weight.data(); + dim3 grid(m); + + if (n % 8 == 0 && std::is_same::value) { + dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); + + rmsnorm_twoPassAlgo_e8<<>>( + (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); + } else { + dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); + + rmsnorm_twoPassAlgo_e1<<>>( + output, input, weight, m, n, epsilon); + } + + auto result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl; + abort(); + } +} + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9747d50975d7d35df287f6b056aedc489adb317c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief utils code for device cutlass code +*/ + +#pragma once + +#include +#include +#define FINAL_MASK 0xffffffff + +struct half4 { + half x, y, z, w; +}; + +template +__inline__ __device__ T warpReduceSum(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T* val) +{ + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceMax(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[lane][i] : (T)(-FLT_MAX); + } + warpReduceMax(val); + + return (T)0.0f; +} + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h new file mode 100644 index 0000000000000000000000000000000000000000..6565aba9607ad68defacb6e98d9f9bbc944cd48d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief This header contains a class to parametrize a statistical distribution function. +*/ + +#include + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Distribution type +struct Distribution { + /// Variant types + enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; + + /// Distribution state + union { + /// Uniform distribution + struct { + double min; + double max; + // Percent elements set to NaN + double pnan; + } uniform; + + /// Gaussian distribution + struct { + double mean; + double stddev; + double pnz; + double pnzA; + double pnzB; + double pnzC; + } gaussian; + + /// Elements are linear combination of row and column index + struct { + double start; + double delta; + } sequential; + }; + + /// Active variant kind + Kind kind; + + /// Random values are cast to integer after scaling by this power of two + int int_scale; + + // + // Methods + // + + Distribution() : kind(Invalid), int_scale(0) {} + +/// Configures distribution as uniform random + Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) { + kind = Uniform; + uniform.min = _min; + uniform.max = _max; + int_scale = _int_scale; + uniform.pnan = _pnan; + return *this; + } + + /// Configures distribution as Gaussian distribution + Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) { + kind = Gaussian; + gaussian.mean = _mean; + gaussian.stddev = _stddev; + gaussian.pnz = _pnz; + gaussian.pnzA = _pnz; + gaussian.pnzB = _pnz; + gaussian.pnzC = _pnz; + int_scale = _int_scale; + return *this; + } + + /// Sets identity + Distribution &set_identity() { + kind = Identity; + return *this; + } + + /// Sets sequential + Distribution &set_sequential(double start, double delta, int _int_scale = 0) { + kind = Sequential; + sequential.start = start; + sequential.delta = delta; + int_scale = _int_scale; + return *this; + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints a Distribution to ostream +inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { + switch (dist.kind) { + case cutlass::Distribution::Uniform: + out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max + << ", pnan: " << dist.uniform.pnan; + break; + case cutlass::Distribution::Gaussian: + out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev + << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " + << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; + break; + case cutlass::Distribution::Identity: + out << "identity"; + break; + case cutlass::Distribution::Sequential: + out << "sequential"; + break; + default: + out << "unknown"; + } + + out << ", int_scale: " << dist.int_scale; + + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..f2b7df6cb1c465a312d76566768cb79fcdfffee4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ exception semantics for CUDA error codes + */ + +#include +#include +#include + +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/// C++ exception wrapper for CUDA \p cudaError_t +class cuda_exception : public std::exception { + public: + /// Constructor + cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} + + /// Returns the underlying CUDA \p cudaError_t + cudaError_t cudaError() const { return err; } + + protected: + /// Explanatory string + const char* msg; + + /// Underlying CUDA \p cudaError_t + cudaError_t err; +}; + +/// Writes a cuda_exception instance to an output stream +inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { + return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); +} + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be2264466e350c062900a50e27e923847186d084 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp @@ -0,0 +1,369 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT command line parser to gather semantic modes, their stride order, and extents. +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +namespace cutlass { + +// Output shortcuts +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a << " "; + return os; +} + +struct GettCommandLine { + struct GettProblem { + using extent_type = int; + using stride_type = int64_t; + + // Row modes: appear in A and C/D + std::vector M; + std::vector ldAm; + std::vector ldCm; + + // Column modes: appear in B and C/D + std::vector N; + std::vector ldBn; + std::vector ldCn; + + // Reduction modes: appear in A and B + std::vector K; + std::vector ldAk; + std::vector ldBk; + + // Batch modes: appear in all in/out tensors + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + }; + + static GettProblem + parse(int argc, char const* argv[], bool parse_verbose = false) { + using extent_type = typename GettProblem::extent_type; + using stride_type = typename GettProblem::stride_type; + + cutlass::CommandLine cmd(argc, argv); + + // modeA + std::vector a_mode; + cmd.get_cmd_line_arguments("modeA", a_mode); + + // modeB + std::vector b_mode; + cmd.get_cmd_line_arguments("modeB", b_mode); + + // modeC + std::vector c_mode; + cmd.get_cmd_line_arguments("modeC", c_mode); + + + // mode_sizes + std::map mode_size; + // First, initialize all modes in a, b, c to make sure they're in map + for (char a : a_mode) mode_size[a] = 1; + for (char b : b_mode) mode_size[b] = 1; + for (char c : c_mode) mode_size[c] = 1; + + // Then, overwrite the ones in -extent + std::vector > extent_tokens; + cmd.get_cmd_line_argument_pairs("extents", extent_tokens); + for (auto e : extent_tokens) { + if (std::get<0>(e).size() > 1) { + std::cerr << "ERROR: Mode name must only be 1 character long.\n"; + print_usage(); + exit(1); + } + char label = std::get<0>(e)[0]; + int size = std::stoi(std::get<1>(e)); + mode_size[label] = size; + } + + // Print out symbolic modes and their extents + if (parse_verbose) { + std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n"; + for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n"; + } + + // + // Collect/Compute strides + // + + std::map mode_ldA; + std::map mode_ldB; + std::map mode_ldC; + + { + stride_type current; + + current = 1; + for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; } + + current = 1; + for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; } + + current = 1; + for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; } + } + + // + // Collect mode categories + // + + std::vector row_mode; // rows + std::vector col_mode; // columns + std::vector red_mode; // reductions + std::vector bat_mode; // batches + + { + std::vector a_label = a_mode; + std::vector b_label = b_mode; + std::vector c_label = c_mode; + + std::sort(std::begin(a_label), std::end(a_label)); + std::sort(std::begin(b_label), std::end(b_label)); + std::sort(std::begin(c_label), std::end(c_label)); + + // std::set_intersections to find semantic category of each symbolic mode + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(row_mode)); + + std::set_intersection(std::begin(b_label), std::end(b_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(col_mode)); + + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(b_label), std::end(b_label), + std::back_inserter(red_mode)); + + std::set_intersection(std::begin(row_mode), std::end(row_mode), + std::begin(col_mode), std::end(col_mode), + std::back_inserter(bat_mode)); + + // std::set_difference to remove batch modes from other semantic modes + for (char l : bat_mode) { + row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode)); + col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode)); + red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode)); + } + } + + // Print out the semantic association of each symbolic mode + if (parse_verbose) { + std::cout << " rows : " << row_mode << '\n'; + std::cout << " cols : " << col_mode << '\n'; + std::cout << " reds : " << red_mode << '\n'; + std::cout << " bats : " << bat_mode << '\n'; + } + + // + // Permute modes + // + + // Permute the batched modes to promote coalescing + // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size + std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { + return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) + < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + for (char l : bat_mode) { + L.push_back(mode_size[l]); + ldAl.push_back(mode_ldA[l]); + ldBl.push_back(mode_ldB[l]); + ldCl.push_back(mode_ldC[l]); + } + + // Permute the reduction modes to promote coalescing + // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size + std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { + return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) + < std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector K; + std::vector ldAk; + std::vector ldBk; + for (char k : red_mode) { + K.push_back(mode_size[k]); + ldAk.push_back(mode_ldA[k]); + ldBk.push_back(mode_ldB[k]); + } + + // Permute the row modes to promote coalescing + // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm + std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) { + return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) + < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); + }); + // Compute sizes and strides of ordered row modes + std::vector M; + std::vector ldAm; + std::vector ldCm; + for (char m : row_mode) { + M.push_back(mode_size[m]); + ldAm.push_back(mode_ldA[m]); + ldCm.push_back(mode_ldC[m]); + } + + // Permute the col modes to promote coalescing + // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn + std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { + return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) + < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); + }); + // Compute sizes and strides of ordered col modes + std::vector N; + std::vector ldBn; + std::vector ldCn; + for (char n : col_mode) { + N.push_back(mode_size[n]); + ldBn.push_back(mode_ldB[n]); + ldCn.push_back(mode_ldC[n]); + } + + if (parse_verbose) { + std::cout << "C_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " = A_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " * B_"; + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << '\n'; + + int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{}); + int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{}); + int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{}); + int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{}); + + std::cout << " M : (" << M_size << ") "; + for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " "; + std::cout << '\n'; + std::cout << " N : (" << N_size << ") "; + for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " "; + std::cout << '\n'; + std::cout << " K : (" << K_size << ") "; + for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " "; + std::cout << '\n'; + std::cout << " L : (" << L_size << ") "; + for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " "; + std::cout << '\n'; + + std::cout << " ldAm : " << ldAm << '\n'; + std::cout << " ldAk : " << ldAk << '\n'; + std::cout << " ldAl : " << ldAl << '\n'; + std::cout << " ldBn : " << ldBn << '\n'; + std::cout << " ldBk : " << ldBk << '\n'; + std::cout << " ldBl : " << ldBl << '\n'; + std::cout << " ldCm : " << ldCm << '\n'; + std::cout << " ldCn : " << ldCn << '\n'; + std::cout << " ldCl : " << ldCl << '\n'; + } + + return {M, ldAm, ldCm, + N, ldBn, ldCn, + K, ldAk, ldBk, + L, ldAl, ldBl, ldCl}; + } + + static void + print_usage() { + std::cout << + "GETT problem command line parser:\n" + " --modeA=\n" + " A comma delimited list of characters that correspond to the row, reduction, and batch modes in A tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeB=\n" + " A comma delimited list of characters that correspond to the column, reduction, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeC=\n" + " A comma delimited list of characters that correspond to the row, column, and batch modes in B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --extents=\n" + " A command delimited list of symbolic mode and its corresponding extent.\n" + " Extents are defaulted to 1 if any are not provided.\n\n" + + "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n"; + } +}; + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58d08b860c9e665d170fd022ed0d95875e029019 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +namespace cute +{ + +void +device_init(int device_id, bool quiet = false) +{ + cudaDeviceProp device_prop; + std::size_t device_free_physmem; + std::size_t device_total_physmem; + + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + + //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; + + if (!quiet) { + printf("Using device %d: %s (SM%d, %d SMs)\n", + device_id, device_prop.name, + device_prop.major * 10 + device_prop.minor, + device_prop.multiProcessorCount); + fflush(stdout); + } +} + +/** + * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. + */ +inline int +_ConvertSMVer2Cores(int major, int minor) +{ + // Defines for GPU Architecture types (using the SM version to determine + // the # of cores per SM + typedef struct { + int SM; // 0xMm (hexadecimal notation), M = SM Major version, + // and m = SM minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = { + {0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf("MapSMtoCores for SM %d.%d is undefined." + " Default to use %d Cores/SM\n", + major, minor, nGpuArchCoresPerSM[index - 1].Cores); + + return nGpuArchCoresPerSM[index - 1].Cores; +} + +} // end namespace cute diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h new file mode 100644 index 0000000000000000000000000000000000000000..4e7718059dfaea0c77d7ebf67789f307b4ca0cf6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief reorder data from the host side +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +/// This is needed for the interleaved integer tensor core kernels. The purpose +/// is to use skip the shared memory part in the epilogue. +template +void reorder_column(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + const int InstructionShapeCol = 8; + // 4 threads per Quad + const int ElementsPerThread = InstructionShapeCol / 4; + // 4 threads per Quad + const int ReorderedElementsPerThread = + Interleaved / 4; + + for (int n = 0; n < problem_size.n(); n++) { + for (int k = 0; k < problem_size.k(); k++) { + dest.at({k, (n / Interleaved) * Interleaved + + ((n % ReorderedElementsPerThread) / ElementsPerThread) * + InstructionShapeCol + + ((n % Interleaved) / ReorderedElementsPerThread) * + ElementsPerThread + + (n % ElementsPerThread)}) = src.at({k, n}); + } + } +} + +template +void reorder_convK(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + + TensorRef> mappedDest(dest.data(), dest.stride(0)); + TensorRef> mappedSrc(src.data(), src.stride(0)); + + reorder_column( + mappedDest, mappedSrc, problem_size); +} + +/// This is needed for the sparse tensor core kernels. The purpose +/// is to use ldmatrix to load from shared memory to the register file. +template +void reorder_meta(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + for (int m = 0; m < problem_size.m(); m++) { + for (int k = 0; k < problem_size.k(); k++) { + // First reorder the rows. + int group = (sizeof(Element) == 2) ? 32 : 16; + int interweave = (sizeof(Element) == 2) ? 4 : 2; + + int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; + int dest_col = k; + + // Next swizzle the 2x2 blocks from Z to N. + if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { + ++dest_row; + --dest_col; + } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { + --dest_row; + ++dest_col; + } + + dest.at({dest_row, dest_col}) = src.at({m, k}); + } + } +} +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..3226055ad0836e7a3059340ff16d54594987e0c8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensor { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRef; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorView; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = StorageContainerCalculator::kContainerTypeNumStorageUnit; + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensor() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensor( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensor( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensor() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); +#endif + + device_.reset(); + host_.clear(); + + size_t count_container = count_to_container_storage_unit_count(count); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); +#endif + host_.resize(count_container); + + // Allocate memory + StorageUnit* device_memory = nullptr; + if (device_backed_) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); +#endif + device_memory = device_memory::allocate(count_container); + } + device_.reset(device_memory, device_backed_ ? count_container : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); + + if (static_cast(new_size_container) > host_.size()) { + reserve(new_size, device_backed_); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the logical number of elements stored in the host tensor + size_t size() const { + return layout_.capacity(extent_); + } + + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). + LongIndex capacity() const { + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; + } + + /// Gets pointer to host data + Element * host_data() { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to device data + Element * device_data() { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data + Element const * device_data() const { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + + /// Returns the layout object + Layout & layout() { + return layout_; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_.data(), device_.get(), device_.size()); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_.get(), host_.data(), host_.size()); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + host_.data(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + device_.get(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + device_.get(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + host_.data(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + reinterpret_cast(ptr_host), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + reinterpret_cast(ptr_device), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + reinterpret_cast(ptr_device), host_.data(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + reinterpret_cast(ptr_host), host_.data(), container_count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..ca770e4d76cfe2df16309baca0b2de8ab6de98c4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h @@ -0,0 +1,591 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensorPlanarComplex { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRefPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorViewPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + + private: + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensorPlanarComplex() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensorPlanarComplex( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensorPlanarComplex( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensorPlanarComplex() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated + + device_.reset(); + host_.clear(); + + host_.resize(count * 2); + + // Allocate memory + Element* device_memory = nullptr; + if (device_backed_) { + device_memory = device_memory::allocate(count * 2); + } + device_.reset(device_memory, device_backed_ ? count * 2 : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + + if (static_cast(new_size * 2) > host_.size()) { + reserve(new_size); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the number of elements stored in the host tensor + size_t size() const { + return host_.size() / 2; + } + + /// Returns the logical capacity based on extent and layout. May differ from size(). + LongIndex capacity() const { + return layout_.capacity(extent_); + } + + /// Stride between real and imaginary parts + LongIndex imaginary_stride() const { + return host_.size() / 2; + } + + /// Gets pointer to host data + Element * host_data() { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element * host_data_imag() { return host_.data() + imaginary_stride(); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } + + /// Gets pointer to host data with a pointer offset + Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to device data + Element * device_data() { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } + + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } + + /// Gets a pointer to the device data imaginary part + Element * device_data_imag() { return device_.get() + imaginary_stride(); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { + return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_real() { + return cutlass::TensorRef(host_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_imag() { + return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { + return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_real() { + return cutlass::TensorRef(device_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_imag() { + return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_real() { + return cutlass::TensorView(host_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_imag() { + return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_real() { + return cutlass::TensorView(device_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_imag() { + return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_data(), device_data(), imaginary_stride() * 2); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_data(), host_data(), imaginary_stride() * 2); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + host_data(), ptr_device_real, count); + + device_memory::copy_to_host( + host_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + device_data(), ptr_device_real, count); + + device_memory::copy_device_to_device( + device_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + device_data(), ptr_host_real, count); + + device_memory::copy_to_device( + device_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + host_data(), ptr_host_real, count); + + device_memory::copy_host_to_host( + host_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host_real, ///< source device memory + Element * ptr_host_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + ptr_host_real, device_data(), count); + + device_memory::copy_to_host( + ptr_host_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + ptr_device_real, device_data(), count); + + device_memory::copy_device_to_device( + ptr_device_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + ptr_device_real, host_data(), count); + + device_memory::copy_to_device( + ptr_device_imag, host_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host_real, ///< source host memory + Element * ptr_host_imag, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + ptr_host_real, host_data(), count); + + device_memory::copy_host_to_host( + ptr_host_imag, host_data_imag(), count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h new file mode 100644 index 0000000000000000000000000000000000000000..9cd62927432c65ce1f0187f46306f7e1198a1182 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief uncompress sparse matrix from the host side +*/ +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +// uncompress sparse tensor core A matrix +template +void uncompress(TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef tensor_e, int row, int col) { + // How many uncompressed data we can get with ElementE meta data + int DecompressedElementsPerElementE = + 256 / cutlass::sizeof_bits::value; + + // Process 4bit meta data a time + int step; + + // 1:2 or 2:4 or 4:8 + int a, b; + + if (cutlass::sizeof_bits::value == 4) { + step = 8; + a = 4; + b = 8; + } else if (cutlass::sizeof_bits::value == 8) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 16) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 32) { + step = 2; + a = 1; + b = 2; + } + + int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; + + for (int r = 0; r < row; ++r) { + for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { + + ElementE meta = tensor_e.at(MatrixCoord(r, c)); + + for (int i = 0; i < DecompressedElementsPerElementE; i += step) { + int e = (meta >> (i / step * 4)) & 0xf; + int idx0 = e & 0x3; + int idx1 = e >> 2; + + if (a == 1) idx0 = idx0 / 2; + + for (int ii = 0; ii < step; ii += ElementsPerE) { + int real_col = + c * DecompressedElementsPerElementE + i + ii; + int compressed_col = (real_col / b) * a; + + if (ii == (idx0 * ElementsPerE)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at(MatrixCoord(r, compressed_col + 1)); + } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at( + MatrixCoord(r, compressed_col + ElementsPerE + 1)); + } else { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + ElementA(0); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + ElementA(0); + } + } + } + } + } +} + +// uncompress ELL block sparse matrix +template +void uncompress_ell_block_sparse( + TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef ell_idx, + int rows, int cols, + int ell_num_cols, int ell_blocksize) { + + for (int r = 0; r < rows / ell_blocksize; ++r) { + for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { + + ElementE idx = ell_idx.at(MatrixCoord(r, c)); + + if (idx != -1) { + int row_begin = r * ell_blocksize; + int col_begin_real = idx * ell_blocksize; + int col_begin = c * ell_blocksize; + + for (int i = 0; i < ell_blocksize; ++i) { + for (int j = 0; j < ell_blocksize; ++j) { + uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = + tensor_a.at( + MatrixCoord(row_begin + i, col_begin +j)); + } + } + } + } + } +} + +} // namespace cutlass + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h new file mode 100644 index 0000000000000000000000000000000000000000..6b72b043fc0c1271cf9f12e5cb9a81d29659cb0a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h @@ -0,0 +1,38 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +// integer_sequence moved to cutlass/numeric_types.h + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..43f5a3f92d29f229703cc4c5f9071c11d0f89df4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -0,0 +1,472 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for mixed input data type kernels. +*/ + +#pragma once + +#include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/mma_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cute/util/type_traits.hpp" + +namespace cutlass { + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleBroadCastLayout, + class ThrLayout> +__global__ void dequantize_kernel(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. + // These are expected to have shape [MN, K, L] + cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); + cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr(q_buffer), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); + auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); + cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); + cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); + cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); + + cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = cute::size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < cute::size<0>(operand_layout)) { + cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); + cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); + cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{}); + cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{}); + cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); + cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout> +static void dequantize(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size, + cudaStream_t &stream) { + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] + + if (num_rows != size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = gemm_k; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. + storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). +static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) { + + using StorageType = cutlass::int4b_t::Storage; + constexpr int pack = cute::sizeof_bits_v / 4; + const size_t host_buf_size = block_size / pack; + std::vector host_buf(host_buf_size); + cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size); + + for (auto&& d : host_buf) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; i++) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size); + return true; +} + +template +static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array *block_out, const size_t block_size) { + std::vector data_in(block_size); + std::vector> data_out(block_size); + + try { + cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + + for (size_t i = 0; i < block_size; i++) { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + } + + try { + cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +struct UnderlyingElement { + using type = T; +}; + +template +struct UnderlyingElement> { + using type = typename T::Element; +}; + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +// In addition, we can reorder the values across several MMA instructions to get even wider +// vectorization (AtomLayout parameter) and permute the values within each instruction to get +// more optimal conversion instruction sequences (ValLayout parameter). +template , + class ValLayout = cute::Layout> +constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) +{ + using namespace cute; + + static_assert(is_static_v, "ValLayout must be static"); + static_assert(is_static_v, "AtomLayout must be static"); + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); + + // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) + // Note: this assumes A is partitioned between warps along M mode + auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); + + // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization + auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); + + // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); + auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); + auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor S, + cute::Tensor D, + TiledCopy tiled_copy) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + + Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + + auto thread_copy = tiled_copy.get_slice(threadIdx.x); + Tensor tS = thread_copy.partition_S(gS); + Tensor tD = thread_copy.partition_D(gD); + + copy(tiled_copy, tS, tD); +} + +template +void reorder_tensor( + cute::Tensor S, + cute::Tensor D) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + static_assert(is_same_v, T>, "Type mismatch"); + + // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread + // This avoids a race condition when writing out subbyte types (e.g. int4b_t). + auto has_major_mode = [](auto s) { + return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; }); + }; + static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), + "Could not find stride-1 mode in destination layout"); + constexpr int N = shape_div(Int<8>{}, Int>{}); + auto val_layout = conditional_return(LayoutDst{}))>( + make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), + make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); + + // Make a tiled copy with a simple row-major thread order and above layout + int constexpr NumThreads = 128; + auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); + auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); + + // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper + using TileShape = Shape<_16>; + auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); + dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; + + reorder_tensor_kernel<<>>(S, D, tiled_copy); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// In-place version +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + using namespace cute; + reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), + make_tensor(make_gmem_ptr(dst), layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + using namespace cute; + cutlass::DeviceAllocation temp(size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); +} + +#undef CUDA_CHECK + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp new file mode 100644 index 0000000000000000000000000000000000000000..811ba152ab7c6e8fafc1cebdbb3726798fd16b3c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/container/array.hpp" // cute::array +#include "cutlass/conv/convolution.h" // cutlass::conv::Operator + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with group mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides for convolutions + +// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) +// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order +// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout +// right in KTRSC order and can be coalesced to just k. +// We enforce this condition here with asserts. +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, cute::Int<0>> s, + cute::array shape_output, + cute::array stride_output, + cutlass::conv::Operator conv_op) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + static_assert(RankT_ >= 3u); + constexpr static int RankT = static_cast(RankT_); + + assert(stride_output[RankT-1] == 1); + cute::for_each(cute::make_seq{}, [&](auto i) { + assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); + }); + + auto s_copy = s; + cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? + stride_output[0] : + stride_output[RankT-2]; + return s_copy; +} + +// +// Activation tensor ((w, h, d, n), _1) for fprop kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nwc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_nwc[1]; + cute::get<0,1>(s_copy) = stride_nwc[0]; + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nhwc[3] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_nhwc[2-i]; + }); + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; + }); + return s_copy; +} + +// +// Filter tensor (k, (_1, s, r, t)) for fprop kernel +// + +// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>> +make_cute_packed_stride( + cute::Stride, IntT>> s, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>> s, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>> s, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + +// +// Activation tensor (_1, (w, h, d, n)) for wgrad kernel +// +// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad +// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nwc[2] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::get<1,0>(s_copy) = stride_nwc[1]; + cute::get<1,1>(s_copy) = stride_nwc[0]; + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nwc in dgrad is ksc. + cute::get<1,0>(s_copy) = stride_nwc[0]; + cute::get<1,1>(s_copy) = stride_nwc[1]; + } + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad +// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nhwc[3] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_nhwc[2-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nhwc in dgrad is krsc. + cute::get<1,0>(s_copy) = stride_nhwc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; + }); + } + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_ndhwc in dgrad is ktrsc. + cute::get<1,0>(s_copy) = stride_ndhwc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; + }); + } + return s_copy; +} + +// +// NZPQ tensor (_1, nzpq) for wgrad kernel +// + +// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nqk[2] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nqk[1]; + return s_copy; +} + +// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_npqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_npqk[3] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_npqk[2]; + return s_copy; +} + +// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nzpqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nzpqk[4] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nzpqk[3]; + return s_copy; +} + + + +// +// Wgrad output tensor (k, (_1, s, r, t), _0) +// + +// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c38ad3f710c18e5be1bb7e01dc66d7efcd2646d9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +// The computed infinity norm does not include +// any NaN column absolute-value sums. +struct matrix_inf_norm_result { + // Accumulate errors in double, as this is generally + // the highest precision that the examples use. + double inf_norm = 0.0; + bool found_nan = false; +}; + +// In theory, cute::Tensor, T> could be treated as a view type, +// and thus passed by value (as std::span or std::string_view would be). +// However, generic cute::Tensor are more like containers +// and thus are best passed by reference or const reference. +template +matrix_inf_norm_result +matrix_inf_norm(cute::Tensor const& host_matrix) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + error_type inf_norm = 0.0; + bool found_nan = false; + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(host_matrix); + const int64_t num_cols = cute::size<1>(host_matrix); + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for(int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += abs_fn(host_matrix(i, j)); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +// Infinity norm of (X - Y). +template +matrix_inf_norm_result +matrix_diff_inf_norm(cute::Tensor const& X, + cute::Tensor const& Y) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + assert(cute::size<0>(X) == cute::size<0>(Y)); + assert(cute::size<1>(X) == cute::size<1>(Y)); + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(X); + const int64_t num_cols = cute::size<1>(X); + + error_type inf_norm = 0.0; + bool found_nan = false; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for (int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += error_type(abs_fn(element_type(X(i,j)) - + element_type(Y(i,j)))); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + char const A_value_type_name[], + cute::Tensor const& A, + char const B_value_type_name[], + cute::Tensor const& B, + char const C_value_type_name[], + cute::Tensor const& C, + cute::Tensor const& C_ref) +{ + const auto [A_norm, A_has_nan] = matrix_inf_norm(A); + const auto [B_norm, B_has_nan] = matrix_inf_norm(B); + const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref); + const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref); + + const auto A_norm_times_B_norm = A_norm * B_norm; + const auto relative_error = A_norm_times_B_norm == 0.0 ? + diff_norm : (diff_norm / A_norm_times_B_norm); + + // For expected error bounds, please refer to the LAPACK Users' Guide, + // in particular https://netlib.org/lapack/lug/node108.html . + // Printing the infinity norm of C is a way to check + // that both the function being tested (C) + // and the reference implementation (C_ref) + // don't just do nothing (or fill with zeros). + using std::cout; + using cute::shape; + cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' + << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' + << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' + << std::scientific + << "Infinity norm of A: " << A_norm << '\n' + << "Infinity norm of B: " << B_norm << '\n' + << "Infinity norm of C: " << C_norm << '\n' + << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; + + if(A_norm_times_B_norm == 0.0) { + cout << "Mollified relative error: " << relative_error << '\n'; + } else { + cout << "Relative error: " << relative_error << '\n'; + } + + if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n'; + } + return relative_error; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + const char value_type_name[], + const cute::Tensor& A, + const cute::Tensor& B, + const cute::Tensor& C_computed, + const cute::Tensor& C_expected) +{ + return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, + value_type_name, C_computed, C_expected); +} + +// Take a CUTLASS HostTensor (or the like) as input, +// and return a const CuTe Tensor. +// This is useful for use with the above error printing functions. +// This implicitly "transposes" if the layout is RowMajor. +// Note that the HostTensor must be captured by nonconst reference +// in order for X.host_ref().data() to compile. +// (CUTLASS is a bit more container-y than CuTe.) +template +auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) +{ + // The tensors were created with post-transposed extents. + const auto extents = X.extent(); + const auto shape = cute::Shape{extents[0], extents[1]}; + // Both RowMajor and ColumnMajor only store one stride. + const int LDX = X.stride(0); + const auto strides = [&]() { + using input_layout_type = typename std::decay_t::Layout; + if constexpr (std::is_same_v) { + return cute::Stride{1, LDX}; + } + else { + static_assert(std::is_same_v); + return cute::Stride{LDX, 1}; + } + }(); + const auto layout = cute::make_layout(shape, strides); + auto X_data = X.host_ref().data(); + auto X_data_const = const_cast >(X_data); + return cute::make_tensor(X_data_const, layout); +}; + + +// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE. +// This makes the return value suitable as the return value of main(). +template +int +print_relative_error( + std::size_t n, + T1 const& data, + T2 const& reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + using std::abs; using std::sqrt; + + // Use either double or complex for error computation + using value_type = cute::remove_cvref_t; + using error_type = std::conditional_t::value, + cute::complex, + double>; + + if (print_verbose) { + std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl; + } + + double eps = 1e-200; + + double tot_error_sq = 0; + double tot_norm_sq = 0; + double tot_ind_rel_err = 0; + double max_ind_rel_err = 0; + double max_diff = 0; + for (std::size_t i = 0; i < n; ++i) { + error_type val = data[i]; + error_type ref = reference[i]; + + double aref = abs(ref); + double diff = abs(ref - val); + double rel_error = diff / (aref + eps); + + // Individual relative error + tot_ind_rel_err += rel_error; + + // Maximum relative error + max_ind_rel_err = std::max(max_ind_rel_err, rel_error); + + // Maximum delta in value error + max_diff = std::max(max_diff, diff); + + // Total relative error + tot_error_sq += diff * diff; + tot_norm_sq += aref * aref; + + if (print_verbose) { + std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl; + } + } + + double ave_rel_err = tot_ind_rel_err / double(n); + if (print_error) { + printf("Average relative error: %.3e\n", ave_rel_err); + } + + if (print_error) { + printf("Maximum relative error: %.3e\n", max_ind_rel_err); + } + + if (print_error) { + printf("Maximum difference : %.3e\n", max_diff); + } + + double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps)); + if (print_error) { + printf("Vector relative error: %.3e\n", tot_rel_err); + } + + printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq)); + + return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE; +} + +// Overload for cute::Tensor<> +template +int +print_relative_error( + cute::Tensor data, + cute::Tensor reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + assert(size(data) == size(reference)); + return print_relative_error(static_cast(size(data)), + data, reference, + print_verbose, print_error, error_margin); +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h new file mode 100644 index 0000000000000000000000000000000000000000..8167c91bf2330d160a78ba210449357b395964ca --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +namespace cutlass { +namespace reference { +namespace detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template function to compute an inner product. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a + // host-only type +template +CUTLASS_HOST_DEVICE +Ctype inner_product(Atype a, Btype b, Ctype c) { + return Ctype(a) * Ctype(b) + c; +} + +/// Specialization for matrix multiplication with binary operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int bit = 0; bit < 32; bit++) { + accum += a[bit] ^ b[bit]; + } + return accum + c; +} + +/* +/// Specialization for matrix multiplication with signed 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} + +/// Specialization for matrix multiplication with unsigned 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Cast { + // Default behavior: convert to the destination type +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + static DstType apply(SrcType src) { return static_cast(src); }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static int8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(-128.f, fminf(127.f, src))); + }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static uint8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(0.f, fminf(255.f, src))); + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h new file mode 100644 index 0000000000000000000000000000000000000000..652d622586cb202ecfe69ac892978b649b5d1be7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + + int64_t prod = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - Index; i < Rank; ++i) { + prod *= int64_t(extent[i]); + } + + coord[Rank - Index - 1] = int(idx / prod); + + int64_t residual = idx % prod; + LinearToCoordinateHelper()(coord, residual, extent); + } +}; + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &) const { + coord[Rank - 1] = int(idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinate { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + LinearToCoordinateHelper()(coord, idx, extent); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..7c6f803c47f5c407cf058d40bc8274a448a36dc4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h @@ -0,0 +1,1549 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv2d device reference kernel +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + int c_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Get group id of currnet channel + int c_group_idx = C / c_per_group; + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + int k_group_idx = thread_k / k_per_group; + + if (thread_k < problem_size.K && k_group_idx == c_group_idx) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_z[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, Z, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + int64_t ZPQ = PQ * problem_size.Z; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nzpq = nzpq_start + m; + + thread_n[m] = int(nzpq / ZPQ); + + int64_t residual = nzpq % ZPQ; + thread_z[m] = int(residual / PQ); + + residual = residual % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && + d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + + if (thread_k < problem_size.K) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_z[m] < problem_size.Z && + thread_p[m] < problem_size.P && + thread_q[m] < problem_size.Q) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } // for (n) + + } + } // for (m) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nhw = nhw_start + m; + + thread_n[m] = int(nhw / HW); + + int64_t residual = nhw % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_d[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + int64_t DHW = HW * problem_size.D; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t ndhw = ndhw_start + m; + + thread_n[m] = int(ndhw / DHW); + + int64_t residual = ndhw % DHW; + thread_d[m] = int(residual / HW); + + residual = residual % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (z >= 0 && !(z % problem_size.stride_d) && + p >= 0 && !(p % problem_size.stride_h) && + q >= 0 && !(q % problem_size.stride_w)) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_d[m] < problem_size.D && + thread_h[m] < problem_size.H && + thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t rsc = rsc_start + n; + int64_t residual = rsc % SC; + + thread_r[n] = int(rsc / SC); + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_t[kThreadN]; + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + int64_t RSC = SC * problem_size.R; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t trsc = trsc_start + n; + + thread_t[n] = int(trsc / RSC); + + int64_t residual = trsc % RSC; + thread_r[n] = int(residual / SC); + + residual = residual % SC; + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int Z = 0; Z < problem_size.Z; ++Z) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_t = thread_t[n]; + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - filter_t; + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W && + thread_c[n] < problem_size.C) { + + element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (Q) + } // for (P) + } // for (Z) + } // for (N) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_t[n] < problem_size.T && + thread_r[n] < problem_size.R && + thread_s[n] < problem_size.S && + thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Conv2d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; + int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; + int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; + int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; + int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv2dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; + int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv3dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kDgrad: + return Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kWgrad: + return Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + default: break; + } + + return Status::kErrorNotSupported; +} + +/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kDgrad: + return Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kWgrad: + return Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + default: break; + } + + return Status::kErrorNotSupported; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..7d575d522c1dd87d51f9bc58d09786393c5cfea3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/kernel/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + kernel::Gemm< + TensorRef, + TensorRef, + TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum) { + + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp, + typename ConvertOp +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + static_assert( + TensorRefCollectionA::kRank == 2 && + TensorRefCollectionB::kRank == 2 && + TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), + batch_count + ); + + // Launch a GEMM kernel + kernel::BatchedGemm< + TensorRefCollectionA, + TensorRefCollectionB, + TensorRefCollectionC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + initial_accum + ); +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..bddf596214da62a7aa3177f758db3710dc1d2516 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + if (grid.y <= std::numeric_limits::max()) { + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } else { + // Using bigger thread tile size + int const kBigMblock = 4; + int const kBigNblock = 16; + + dim3 Bigblock(16, 8); + dim3 Biggrid( + (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), + (problem_size.n() + block.y * kBigNblock - 1) / (block.y * kBigNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kBigMblock, + kBigNblock + ><<< Biggrid, Bigblock >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..48819cf6eaa565b3ec41dbbf78ae244666fd8a65 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..497a257d170c411d891942f62fa2c960453d03d5 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT device reference code +*/ +#pragma once + +#include + +namespace cutlass::reference::device { + +template < + class ATensor, + class BTensor, + class CTensor, + class DTensor, + class ElementAccumulator, + class ElementEpilogue> +__global__ static +void +gett_kernel( + DTensor D, + ATensor const A, + BTensor const B, + CTensor const C, + ElementEpilogue alpha, ElementEpilogue beta, + ElementAccumulator acc_init) +{ + using namespace cute; + + static_assert(DTensor::rank == 3, "(M,N,L)"); + static_assert(ATensor::rank == 3, "(M,K,L)"); + static_assert(BTensor::rank == 3, "(N,K,L)"); + static_assert(CTensor::rank == 3, "(M,N,L)"); + + assert(size<0>(A) == size<0>(D)); // M + assert(size<0>(C) == size<0>(D)); // M + assert(size<0>(B) == size<1>(D)); // N + assert(size<1>(C) == size<1>(D)); // N + assert(size<1>(A) == size<1>(B)); // K + assert(size<2>(A) == size<2>(D)); // L + assert(size<2>(B) == size<2>(D)); // L + assert(size<2>(C) == size<2>(D)); // L + + NumericConverter a_converter; + NumericConverter b_converter; + NumericConverter acc_converter; + NumericConverter source_converter; + NumericConverter output_converter; + + // Thread id to each element of D + for (int tid = threadIdx.x + blockDim.x * blockIdx.x; + tid < size(D); + tid += blockDim.x * gridDim.x) { + // (m,n,l) coordinate + auto mnl_coord = idx2crd(tid, product_each(shape(D))); + auto m = get<0>(mnl_coord); + auto n = get<1>(mnl_coord); + auto l = get<2>(mnl_coord); + + auto A_ml = A(m,_,l); + auto B_nl = B(n,_,l); + + ElementAccumulator accum = ElementAccumulator(0); + for (int k = 0; k < size<1>(A); ++k) { + ElementAccumulator a = a_converter(A_ml(k)); + ElementAccumulator b = b_converter(B_nl(k)); + accum += a * b; + } + + ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); + D(m,n,l) = output_converter(scaled_output); + } +} + +// Most general version +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +void +gett( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4); + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto K = get<2>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full tensors + auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) + auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L) + auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) + auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) + + dim3 dimBlock(256); + dim3 dimGrid(240); + gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); +} + +} // namespace cutlass::reference::device diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..6e131126a336420a2b0e843e3ead3d89fce637fa --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/thread/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void Gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefA tensor_a, + TensorRefB tensor_b, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + AccumulatorType initial_accum) { + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Compute the general matrix product + thread::Gemm< + TensorRefA, + TensorRefB, + TensorRefC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void BatchedGemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefCollectionA tensor_collection_a, + TensorRefCollectionB tensor_collection_b, + ScalarType beta, + TensorRefCollectionC tensor_collection_c, + AccumulatorType initial_accum) { + + // Obtain batch ID + int batch_id = blockIdx.z; + + // Dereference based on batch_id + typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); + typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); + typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, + (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow + ); + + // Compute the general matrix product + thread::Gemm< + typename TensorRefCollectionA::TensorRef, + typename TensorRefCollectionB::TensorRef, + typename TensorRefCollectionC::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..149e4b2e00e2ac8130cee9dc189a539ba3a70297 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform random distribution +template +__global__ void TensorInitializeUniform( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + double range = dist.uniform.max - dist.uniform.min; + + double rnd = curand_uniform(&rng_state[threadIdx.x]); + + rnd = dist.uniform.min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + + tensor += ldm; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform distribution +template +__global__ void TensorInitializeGaussian( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + double rnd = curand_normal(&rng_state[threadIdx.x]); + + rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; + + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeLinear( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = + dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeIdentity( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = (c_idx == s_idx ? T(1) : T(0)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..3223cb2056ba6d88f47f7b117392a56e325d0ce7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/fast_math.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for general rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + int64_t product = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - RankRemaining; i < Rank; ++i) { + product *= size[i]; + } + + coord[Rank - 1 - RankRemaining] = index / product; + int64_t remaining = index % product; + + TensorForEachHelper(func, size, coord, remaining); + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for fastest changing rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + coord[Rank - 1] = index; + + if (coord < size) { + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element in a tensor's index space +template +__global__ void TensorForEach(Coord size, Params params = Params()) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t max_index = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + max_index *= size[i]; + } + + CUTLASS_PRAGMA_NO_UNROLL + while (index < max_index) { + Coord coord; + + detail::TensorForEachHelper(func, size, coord, index); + index += blockDim.x * gridDim.x; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element along a tensor's diagonal +template +__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; + + if (index < end) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = index; + } + + func(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params) { + + Func func(params); + + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + + for (; index < capacity; index += blockDim.x * gridDim.x) { + ReferenceFactory::get(ptr, index) = func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..2e76fe52b06f9bb1a033c736f94fa01961ce664d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + assert(M=N); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::Rank2KComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + fill_mode_c, + blas_mode, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..1999730f6d24e69aef152aa332fae68af57a9c40 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" + +#include "cutlass/util/distribution.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template +__global__ void BlockCompareEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (a != b) { + *equal = 0; + + return; + } + } +} + +template +__global__ void BlockCompareRelativelyEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (!relatively_equal(a, b, epsilon, nonzero_floor)) { + *equal = 0; + return; + } + } +} + +} // namespace kernel + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareEqual<<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareRelativelyEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareRelativelyEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareRelativelyEqual<<< grid, block, 0, stream >>>( + device_equal_flag, + ptr_A, + ptr_B, + capacity, + epsilon, + nonzero_floor + ); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // reference +} // cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..a19b42825f6efb4a39466fe1cfc182ab7d831079 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -0,0 +1,2075 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) + +// Standard Library includes +#include +#include +#include +#include +#include + +#endif + +// CUDA includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_view.h" +#include "cutlass/blas3.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/layout/vector.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" +#include "cutlass/util/distribution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +CUTLASS_DEVICE +FloatType random_normal_float(curandState_t *state) { + return curand_normal(state); +} + +template <> +CUTLASS_DEVICE +double random_normal_float(curandState_t *state) { + return curand_normal_double(state); +} + +template +CUTLASS_DEVICE +FloatType random_uniform_float(curandState_t *state) { + return curand_uniform(state); +} + +template <> +CUTLASS_DEVICE +double random_uniform_float(curandState_t *state) { + return curand_uniform_double(state); +} + +template +struct RandomGaussianFunc { + + using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element mean_ = 0, + Element stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd = random_normal_float(&rng_state); + rnd = params.mean + params.stddev * rnd; + + Element result; + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd += FloatType(1); + } else { + rnd -= FloatType(1); + } + result = Element(rnd); + } + + return result; + } +}; + + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r += FloatType(1); + } else { + rnd_r -= FloatType(1); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomGaussianFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomGaussianFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = typename RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean + typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + using Func = detail::TensorFillRandomGaussianFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template ///< Element type +void BlockFillRandomGaussian( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + + typename RandomFunc::Params params(seed, mean, stddev, bits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random uniform distribution +template ///< Element type +struct RandomUniformFunc { + + using FloatType = typename std::conditional< + (sizeof(Element) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Element) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType max; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element max_ = 1, + Element min = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max_) - static_cast(min)), + max(static_cast(max_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + range = (min == Element(0)) ? range - FloatType(1): range; + max = (max_ == Element(0)) ? max - FloatType(1): max; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(NAN); + } + } + + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.max - params.range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd = std::min(params.max, rnd + FloatType(1)); + } else { + rnd = std::max((params.max - params.range), rnd - FloatType(1)); + } + result = Element(rnd); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + min = (min == FloatType(0)) ? min + FloatType(1): min; + range = (max == FloatType(0)) ? range - FloatType(1): range; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(Real(NAN), Real(NAN)); + } + } + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1)); + } else { + rnd_r = std::max((params.min), rnd_r - FloatType(1)); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomUniformFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type max = Element(1), ///< upper bound of distribution + typename RealType::Type min = Element(0), ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + + typename RandomFunc::Params params(seed, max, min, bits, pnan); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random sparse meta +template ///< Element type +struct RandomSparseMetaFunc { + + using FloatType = float; + + using IntType = int32_t; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + int MetaSizeInBits; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), + MetaSizeInBits(MetaSizeInBits_) { + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomSparseMetaFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element *MetaArray = + (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomSparseMetaFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, MetaSizeInBits); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + + typename RandomFunc::Params params(seed, MetaSizeInBits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Updates the tensor + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + params.view.at(coord) = (is_diag ? params.diag : params.other); + } +}; + +// Overwrites the elements of a tensor with a uniform value depending on fill mode +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element element; + FillMode fill_mode; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params(): fill_mode(FillMode::kNone) { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, + Element element_, + FillMode fill_mode_ + ): + view(view_), element(element_), fill_mode(fill_mode_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorFillPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + case FillMode::kFull: + predicate = true; + break; + + case FillMode::kLower: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] < coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kUpper: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] > coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kDiagonal: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] != coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorClearPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// + static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); + + /// Parameters structure + struct Params { + TensorView view{}; + Element element{}; + FillMode fill_mode{FillMode::kNone}; + int alignment{0}; + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorClearPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + + case FillMode::kLower: + if ((coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kUpper: + if ((coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are +/// not written. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, element, fill_mode), + stream + ); +} + +/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side +/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorClearPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + int alignment, + cudaStream_t stream = nullptr) { + + typedef detail::TensorClearPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params{view, element, fill_mode, alignment}, + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView view, ///< destination tensor + Element val = Element(0), ///< value to uniformly fill it with + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, val, val, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView view, ///< destination tensor + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, Element(1), Element(0), stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1) + ): + view(view_), diag(diag_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (is_diag) { + params.view.at(coord) = params.diag; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + params.view.at(coord) = params.other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView view, ///< destination tensor + Element other = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateOffDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillLinearFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element sum = params.s; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + if constexpr (is_complex::value) { + if constexpr (sizeof_bits::value <= 32) { + sum = Element(static_cast>(sum) + + static_cast>(params.v[i]) * static_cast>(coord[i])); + } + } + else if constexpr (sizeof_bits::value <= 32) { + if constexpr (std::numeric_limits::is_integer) { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + else { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + } + else { + sum += params.v[i] * coord[i]; + } + } + + params.view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView view, ///< destination tensor + Array const & v, + Element s = Element(0), + cudaStream_t stream = nullptr) { + + using Func = detail::TensorFillLinearFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, v, s), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr, + int exclude_zero = -1 ///< If non-negative, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. + ) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + exclude_zero, + stream); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + exclude_zero, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + + using Layout = layout::PackedVectorLayout; + Layout::TensorCoord size(static_cast(capacity)); // -Wconversion + Layout layout = Layout::packed(size); + TensorView view(ptr, layout, size); + + Array c{}; + c[0] = v; + + TensorFillLinear(view, c, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalInFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element const *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element const *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { + + } + + /// Only update the diagonal element + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.view.at(coord) = params.ptr[coord[0]]; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView view, ///< destination tensor + Element const *ptr, ///< dense buffer of elements + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalInFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalOutFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.ptr[coord[0]] = params.view.at(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView view, ///< source tensor + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalOutFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..ba2dfd85c47b8c9450c348de32dccb7f1be9c3c1 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/reference/device/kernel/tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element in a tensor's index space. +template +struct TensorForEach { + + /// Constructor performs the operation. + TensorForEach( + Coord size, Params params = Params(), + int grid_size = 0, int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::TensorForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element along a tensor's diagonal +template +struct TensorDiagonalForEach { + + /// Constructor performs the operation + TensorDiagonalForEach( + Coord size, Params params = Params(), + int start = 0, int end = -1, + int block_size = 128, cudaStream_t stream = nullptr) { + + if (end < 0) { + end = size.min(); + } + + dim3 block(block_size, 1, 1); + dim3 grid((end - start + block_size - 1) / block_size, 1, 1); + + kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( + size, params, start, end); + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params(), + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..3e6d7b300f34fec6aec96e72f78427cf677936b4 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h @@ -0,0 +1,514 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/detail/linear_to_coordinate.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + + // Fetch element + Element x = view.at(coord); + + // Transform + identity = reduce(identity, transform(x)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + auto size = static_cast(view_A.size()); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + + // Fetch element + Element a = view_A.at(coord); + Element b = view_B.at(coord); + + // Transform + identity = reduce(identity, transform(a, b)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + + +template < + typename ComputeType, + typename ReduceOp, + int kBlockSize = 32 +> +__global__ void TensorTransformReduceFinalize( + ComputeType *workspace, + ComputeType identity, + int workspace_size, + ReduceOp reduce) { + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { + identity = reduce(identity, workspace[idx]); + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[0] = identity; + } +} + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + cudaStreamSynchronize(stream); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of two tensors, zipped together +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Extents must be equal."); + } + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view_A, view_B, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + cudaStreamSynchronize(stream); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view_A, + view_B, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform, stream, workspace_size); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h new file mode 100644 index 0000000000000000000000000000000000000000..0e3d99ddf845810249f909fbdee4505a0a732c4f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorReLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element threshold; + + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element threshold_ = Element(0) + ): + view(view_), threshold(threshold_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorReLuFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element const & value = params.view.at(coord); + params.view.at(coord) = (value < params.threshold) ? params.threshold : value; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply ReLu on a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorReLu( + TensorView view, ///< destination tensor + Element threshold = Element(0)) { ///< ReLu threshold + + using Func = detail::TensorReLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, threshold) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..dd11f96bd92f6995590e61665e41a3e830bceacd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h @@ -0,0 +1,186 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace thread { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level blocked general matrix product. +// +// Note, this is a reference implementation. Performance is not expected to approach peak. +// +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +struct Gemm { + + using ElementA = typename TensorRefA::Element; + using ElementB = typename TensorRefB::Element; + using ElementC = typename TensorRefC::Element; + + // + // Data members + // + + /// Tile for A operand + ElementA A_tile[OutputTile::kColumn]; + + /// Tile for B operand + ElementB B_tile[OutputTile::kRow]; + + /// Tile for Accumulator + AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; + + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { + + // Clear fetch registers + for (int i = 0; i < OutputTile::kColumn; ++i) { + A_tile[i] = ElementA(0); + } + + for (int j = 0; j < OutputTile::kRow; ++j) { + B_tile[j] = ElementB(0); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kColumn; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kRow; ++i) { + accum[j][i] = initial_accum; + } + } + } + + /// Computes a matrix product + CUTLASS_HOST_DEVICE + Gemm & multiply_add( + gemm::GemmCoord problem_size, + TensorRefA tensor_a, + TensorRefB tensor_b, + MatrixCoord output_coord = MatrixCoord()) { + + InnerProductOp inner_product_op; + + // Loop over the GEMM K dimension + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < problem_size.k(); ++k) { + + // Fetch a slice of the A matrix + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + if (output_coord.row() + i < problem_size.m()) { + A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); + } + } + + // Fetch a slice of the B matrix + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + if (output_coord.column() + j < problem_size.n()) { + B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); + } + } + + // Compute an accumulated matrix product + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); + } + } + } + + return *this; + } + + /// Performs linear scaling of matrix product and updates output tensor + CUTLASS_HOST_DEVICE + Gemm & epilogue( + gemm::GemmCoord problem_size, + ScalarType alpha, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + MatrixCoord output_coord = MatrixCoord()) { + + ConvertOp convert_op; + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[j][i]) + + beta * ScalarType(tensor_c.at(coord)) + ); + } + } + } + + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..57443325629ea4e5d855fe18f94c73b10a71a73a --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -0,0 +1,782 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for CONV in host-side code. +*/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<5>(activation)) && + (n_ >= 0 && n_ < size<4>(activation)) && + (d_ >= 0 && d_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<4>(activation)) && + (n_ >= 0 && n_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<3>(activation)) && + (n_ >= 0 && n_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +} // namespace detail + +template< + class ElementAcc_, + class ElementScalar_, + class ElementCompute_, + class ElementC_, + class ElementOut_, + bool ResidualAdd_, + class TensorAlpha_, + class TensorBeta_, + class TensorBias_, + class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> +struct ConvEpilogueFusionParams { + using ElementAcc = ElementAcc_; + using ElementScalar = ElementScalar_; + using ElementCompute = ElementCompute_; + using ElementC = ElementC_; + using ElementOut = ElementOut_; + using TensorAlpha = TensorAlpha_; + using TensorBeta = TensorBeta_; + using TensorBias = TensorBias_; + using ActivationFunctor = ActivationFunctor_; + static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorAlpha tensor_alpha{}; + TensorBeta tensor_beta{}; + TensorBias tensor_bias{}; +}; + +template< + cutlass::conv::Operator ConvOp, + int NumSpatialDims, + class TensorA, + class TensorB, + class TensorC, + class TensorD, + class ShapePadding, + class StrideTraversal, + class ShapeDilation, + class EpilogueFusionParams +> +struct ConvReferenceImpl { + // Hard code accumlulator type to float to avoid data lost in accumulating add. + using ElementAcc = cutlass::platform::conditional_t, double, float>; + using ElementC = typename EpilogueFusionParams::ElementC; + using ElementOut = typename EpilogueFusionParams::ElementOut; + using ElementScalar = typename EpilogueFusionParams::ElementScalar; + using ElementCompute = typename EpilogueFusionParams::ElementCompute; + using ElementBias = typename EpilogueFusionParams::TensorBias::value_type; + using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor; + + // Input related converter + NumericConverter acc_converter; + NumericConverter residual_converter; + NumericConverter bias_converter; + // Scale related converter + NumericConverter scale_converter; + // Output related converter + NumericConverter output_converter; + + EpilogueFusionParams& epi_fusion_params_; + TensorA const& tensor_a_; + TensorB const& tensor_b_; + TensorC const& tensor_c_; + TensorD& tensor_d_; + + ShapePadding const& padding_; + StrideTraversal const& tstride_; + ShapeDilation const& dilation_; + + // Epilogue activation operation + ActivationFunctor epi_activation; + + ConvReferenceImpl( + TensorA const& tensor_a, + TensorB const& tensor_b, + TensorC const& tensor_c, + TensorD& tensor_d, + ShapePadding const& padding, + StrideTraversal const& tstride, + ShapeDilation const& dilation, + EpilogueFusionParams& epi_fusion_params) + : tensor_a_(tensor_a), + tensor_b_(tensor_b), + tensor_c_(tensor_c), + tensor_d_(tensor_d), + padding_(padding), + tstride_(tstride), + dilation_(dilation), + epi_fusion_params_(epi_fusion_params) + { + static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); + static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); + } + + void compute_reference() { + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + fprop_reference(cute::Int{}); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + dgrad_reference(cute::Int{}); + } + else { + wgrad_reference(cute::Int{}); + } + } + +private: + // Specialization for 1D fprop kernel + void fprop_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) { + auto a = tensor_a_(c, w, n, g); + auto b = tensor_b_(c, s, k, g); + accumulator += ElementAcc(a * b); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + tensor_d_(k, q, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D fprop kernel + void fprop_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) { + auto a = tensor_a_(c, w, h, n, g); + auto b = tensor_b_(c, s, r, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + tensor_d_(k, q, p, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D fprop kernel + void fprop_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t Z = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) { + auto a = tensor_a_(c, w, h, d, n, g); + auto b = tensor_b_(c, s, r, t, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + tensor_d_(k, q, p, z, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D dgrad kernel + void dgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g)); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + tensor_d_(c, w, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D dgrad kernel + void dgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g)); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + + tensor_d_(c, w, h, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D dgrad kernel + void dgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t D = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<4>(tensor_b_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t d = 0; d < D; ++d) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (z % cute::get<2>(tstride_) == 0) { + z /= cute::get<2>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g)); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + tensor_d_(c, w, h, d, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D wgrad kernel + void wgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, n, g); + auto xformed_act = + tensor_a_(k, q, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + tensor_d_(c, s, k, g) = output_converter(output); + } + } + } + } + } + + // Specialization for 2D wgrad kernel + void wgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, n, g); + auto xformed_act = + tensor_a_(k, q, p, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + tensor_d_(c, s, r, k, g) = output_converter(output); + } + } + } + } + } + } + + // Specialization for 3D wgrad kernel + void wgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = + size<4>(tensor_a_); + int32_t Z = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t T = size<3>(tensor_d_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0 ; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, d, n, g); + auto xformed_act = + tensor_a_(k, q, p, z, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? + epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + tensor_d_(c, s, r, t, k, g) = output_converter(output); + } + } + } + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..73298e5794f0f2658ef18fb3f46466c400fc831e --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h @@ -0,0 +1,802 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Forward propagation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv2d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + int group_idx = k / (problem_size.K / problem_size.groups); + int channels_per_group = problem_size.C / problem_size.groups; + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < channels_per_group; ++c) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); + ElementB b = tensor_w.at({k, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +/// Depthwise-separable convolution +template , + typename InnerProductOp = multiply_add> +void Depsep_Fprop(cutlass::TensorView tensor_A, + cutlass::TensorView tensor_B, + cutlass::TensorView tensor_C, + cutlass::TensorView tensor_D, + ElementCompute alpha, + ElementCompute beta, + cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), + cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), + cutlass::Coord<2> dilation = cutlass::Coord<2>(), + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < tensor_C.extent().n(); ++n) { + for (int p = 0; p < tensor_C.extent().h(); ++p) { + for (int q = 0; q < tensor_C.extent().w(); ++q) { + for (int g = 0; g < tensor_C.extent().c(); ++g) { + ElementAccumulator acc = ElementAccumulator(); + for (int r = 0; r < tensor_B.extent().h(); ++r) { + for (int s = 0; s < tensor_B.extent().w(); ++s) { + + // input activation H and W + int h = p * conv_stride[0] - padding[0] + r * dilation[0]; + int w = q * conv_stride[1] - padding[2] + s * dilation[1]; + + if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { + ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); + + ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) + ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) + : tensor_B.at(cutlass::make_Coord( + g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); + tensor_D.at(cutlass::make_Coord(n, p, q, g)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dDgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; +#if 0 + std::cout << "row:" + << n * problem_size.H * problem_size.W + + h * problem_size.W + + w << " " + << "n, p, q: (" + << n << ", " + << p << ", " + << q << ") * " + << "r, s: (" + << r << ", " + << s << ") [" + << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" + << std::endl; +#endif + if (p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, r, s, c)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dWgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + cutlass::Tensor4DCoord b_coord; + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + b_coord = make_Coord( + n, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + acc = inner_product_op(a, b, acc); + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (K) +} + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// 3D convolution +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv3d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (d >= 0 && d < problem_size.D && + h >=0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, d, h, w, c}); + ElementB b = tensor_w.at({k, t, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dDgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int d = 0; d < problem_size.D; ++d) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (z >= 0 && (z % problem_size.stride_d) == 0 && + p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); + ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + } // for (T) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (D) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dWgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + Tensor5DCoord b_coord = make_Coord( + n, + z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && + b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + + acc = inner_product_op(a, b, acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + } // for (K) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h new file mode 100644 index 0000000000000000000000000000000000000000..12ead83354b785096e8029b49f1ac353d5ce5f82 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h @@ -0,0 +1,66 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/util/reference/host/tensor_reduce.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorRelativeErrorMetric( + TensorView view_A_computed, + TensorView view_B_reference, + ComputeType identity = ComputeType() +) { + + return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / + cutlass::reference::host::TensorNorm(view_B_reference, identity); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2afee7b36d9822cc196f0f167f9dbec4c295d1a6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h @@ -0,0 +1,531 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +namespace cutlass { +namespace reference { +namespace host { + +template +struct CastIfScalar { + static Out cast(In in) { + return Out(in); + } +}; + +template +struct CastIfScalar, In> { + typedef cutlass::complex Out; + static Out cast(In in) { + return Out(static_cast(in)); + } +}; + +template +struct CastIfScalar, cutlass::complex> { + typedef cutlass::complex Out; + typedef cutlass::complex In; + static Out cast(In in) { + return Out(in); + } +}; + +template +Out cast_if_scalar(In in) { + return CastIfScalar::cast(in); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +/// Partial specialization for AND-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); + typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); + typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); + + for (int batch = 0; + batch < batch_count; + ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { + + Gemm + gemm; + + gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, + initial_accum); + } +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..221a6040854a74ce465af7b021bbbfae9b96a90b --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..507c37d9eb5a8c998f1075d547e8430b2edc5685 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd54dc6e378d0d0f0549ec922da8357841ac558f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -0,0 +1,916 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" + +#include "cute/tensor.hpp" +#include "cute/pointer.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_ // (N, K, L) + + , class TensorSfA_ = TensorA_, + class TensorSfB_ = TensorB_ + +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + TensorA A{}; + TensorB B{}; + + ComplexTransform transform_A = ComplexTransform::kNone; + ComplexTransform transform_B = ComplexTransform::kNone; + + + using TensorSfA = TensorSfA_; + using TensorSfB = TensorSfB_; + using EngineSfA = typename TensorSfA::engine_type; + using LayoutSfA = typename TensorSfA::layout_type; + using EngineSfB = typename TensorSfB::engine_type; + using LayoutSfB = typename TensorSfB::layout_type; + TensorSfA_ SfA{}; + TensorSfB_ SfB{}; + + + GettMainloopParams() {} + + GettMainloopParams(TensorA tensor_A, TensorB tensor_B) + : A(tensor_A), B(tensor_B) {} + + + GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : A(tensor_A), SfA(tensor_SfA), + B(tensor_B), SfB(tensor_SfB) {} + + +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorSfA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorSfB_ // (N, K, L) +> +struct GettBlockScalingMainloopParams : public GettMainloopParams { + using Base = GettMainloopParams; + using ElementAccumulator = typename Base::ElementAccumulator; + using TensorA = typename Base::TensorA; + using TensorB = typename Base::TensorB; + using EngineA = typename Base::EngineA; + using LayoutA = typename Base::LayoutA; + using EngineB = typename Base::EngineB; + using LayoutB = typename Base::LayoutB; + ComplexTransform transform_A = Base::transform_A; + ComplexTransform transform_B = Base::transform_B; + + using TensorSfA = typename Base::TensorSfA; + using TensorSfB = typename Base::TensorSfB; + using EngineSfA = typename Base::EngineSfA; + using LayoutSfA = typename Base::LayoutSfA; + using EngineSfB = typename Base::EngineSfB; + using LayoutSfB = typename Base::LayoutSfB; + + GettBlockScalingMainloopParams() {} + + GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {} + + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class SfStrategy { + None = 0, + SfDGen = 1 +}; + + +/////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class TensorAux_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, N, L) + class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class TensorSFD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false + , + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using TensorSFD = TensorSFD_; + using SFD_VectorSize = SFD_VectorSize_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + using EngineSfD = typename TensorSFD::engine_type; + using LayoutSfD = typename TensorSFD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + static constexpr SfStrategy SfGenStrategy = SfGenStrategy_; + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + TensorSFD SfD{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; + GettEpilogueParams() {} + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {} + + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {} + + + GettEpilogueParams( + ElementScalar alpha, ElementScalar beta, + TensorC tensor_C, TensorD tensor_D, + VectorBias bias, TensorAux tensor_aux, + VectorAlpha vector_alpha, VectorBeta vector_beta) + : alpha(alpha), beta(beta), + C(tensor_C), D(tensor_D), + Bias(bias), Aux(tensor_aux), + Valpha(vector_alpha), Vbeta(vector_beta) {} +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, + class TensorD_, + class TensorSfD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettBlockScalingEpilogueParams : public GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // class BiasBinaryOp_ = + false, //PerColumnBias_ + SfGenStrategy_ // SfGenStrategy + > { + using Base = GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // BiasBinaryOp + false, // PerColumnBias + SfGenStrategy_ // SfGenStrategy + >; + using ElementScalar = typename Base::ElementScalar; + using ElementScalingFactor = typename Base::ElementScalingFactor; + using ElementAccumulator = typename Base::ElementAccumulator; + using ElementCompute = typename Base::ElementCompute; + using TensorC = typename Base::TensorC; + using TensorD = typename Base::TensorD; + using TensorAux = typename Base::TensorAux; + using VectorBias = typename Base::VectorBias; + using VectorAlpha = typename Base::VectorAlpha; + using VectorBeta = typename Base::VectorBeta; + using TensorSFD = typename Base::TensorSFD; + using SFD_VectorSize = typename Base::SFD_VectorSize; + using ActivationFunctor = typename Base::ActivationFunctor; + using BiasBinaryOp = typename Base::BiasBinaryOp; + + using EngineC = typename Base::EngineC; + using LayoutC = typename Base::LayoutC; + using EngineD = typename Base::EngineD; + using LayoutD = typename Base::LayoutD; + using EngineSfD = typename Base::EngineSfD; + using LayoutSfD = typename Base::LayoutSfD; + static constexpr bool PerColumnBias = Base::PerColumnBias; + static constexpr SfStrategy SfGenStrategy = Base::SfGenStrategy; + + GettBlockScalingEpilogueParams() {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : Base(alpha, beta, tensor_C, tensor_D) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {} +}; + + + + + +/////////////////////////////////////////////////////////// +// +// Generic Gett 3x Implementation +// +/////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +void compute_1d_scaling_factor_and_quantized_output( + EpilogueParams const& epilogue_params, + TensorD &tensor_D, + TensorSFD &tensor_SfD, + int64_t m, + int64_t n, + int64_t l, + ElementCompute (&acc)[kBlockM][kBlockN]) +{ + using ElementD = typename ElementTraits::type; + using ElementSfD = typename ElementTraits::type; + + int const M = cute::size<0>(tensor_D.layout()); + int const N = cute::size<1>(tensor_D.layout()); + int const L = cute::size<2>(tensor_D.layout()); + + auto mul = cutlass::multiplies{}; + auto div = divides{}; + // Get FP max + ElementCompute fp_max = ElementCompute(std::numeric_limits::max()); + float scale_down_factor = div(1.0f, fp_max); + // Get st' = st / FP max + ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor); + + absolute_value_op abs_op; + maximum_with_nan_propogation max_op; + + if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) { + // MN major output + int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); + // Col major output + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t col = n + n_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_row = m + kVectorSize * v_b; + if (sf_row < M && col < N) { + tensor_SfD(sf_row, col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale); + } + } + } + } + } + else { + int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize); + // row major output + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t row = m + m_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_col = n + kVectorSize * v_b; + + if (row < M && sf_col < N) { + tensor_SfD(row, sf_col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale); + } + } + } + } + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + + + using ElementSFA = typename ElementTraits::type; + using ElementSFB = typename ElementTraits::type; + + + using RingOp = multiply_add; + RingOp fma_op; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFA + auto sfa = static_cast(mainloop_params.SfA(m + m_b, k, l)); + a_frag[m_b] *= sfa; + } + + + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { + a_frag[m_b] = conj(a_frag[m_b]); + } + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFB + auto sfb = static_cast(mainloop_params.SfB(n + n_b, k, l)); + b_frag[n_b] *= sfb; + } + + + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { + b_frag[n_b] = conj(b_frag[n_b]); + } + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementSfD = typename EpilogueParams::TensorSFD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy; + + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + (cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool UseReLU = + cute::is_same_v>; // Treat Clamp as ReLU + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // vector alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // vector beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); + converted_beta = mul(converted_beta, converted_scale_c); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (UseReLU) { + cutlass::epilogue::thread::ReLU relu; + output = relu(output); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + + if constexpr ( + SfGenStrategy == SfStrategy::SfDGen + ) { + // 1d scale factor generation + constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{}; + if (epilogue_params.SfD.data() != nullptr) { + compute_1d_scaling_factor_and_quantized_output(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum); + } + } + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto make_layout_rank3(const TensorType& tensor) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + return make_layout( + make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}), + make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); +} + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + + if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { + cute::Layout layout_A = make_layout_rank3(mainloop_params.A); + cute::Layout layout_B = make_layout_rank3(mainloop_params.B); + cute::Layout layout_C = make_layout_rank3(epilogue_params.C); + cute::Layout layout_D = make_layout_rank3(epilogue_params.D); + cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); + cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); + + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); + auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); + auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); + auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); + auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); + auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); + + // Reconstruct mainloop params + GettMainloopParams + mainloop_params_converted{TensorA, + TensorB, + mainloop_params.transform_A, + mainloop_params.transform_B}; + + // Reconstruct epilogue params + GettEpilogueParams + epilogue_params_converted{epilogue_params.alpha, + epilogue_params.beta, + TensorC, + TensorD, + VectorBias, + TensorAux, + VectorAlpha, + VectorBeta, + epilogue_params.abs_amax_D, + epilogue_params.abs_amax_Aux, + epilogue_params.scale_a, + epilogue_params.scale_b, + epilogue_params.scale_c, + epilogue_params.scale_d, + epilogue_params.scale_aux + }; + + Gett(mainloop_params_converted, epilogue_params_converted); + } + else { + // if we already have a batch mode, just pass it through + Gett(mainloop_params, epilogue_params); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..67867533d5783b6e0047ac2110dc47adaa277e25 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for Rank 2k update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert( + FillModeC == FillMode::kLower || + FillModeC == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), + std::greater_equal, + std::less_equal>::type; + + // Note: batch is ignored. + // Note: M is same as N for Rank 2k update + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < N; row_block += Nblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Nblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < N && col < N && compare_op(row, col)) + { + + // A x B^T + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b_t(cast_if_scalar(b_t)); + + accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); + + // B x A^T + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType compute_b(cast_if_scalar(b)); + ComputeType compute_a_t(cast_if_scalar(a_t)); + + accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < N && col < N && + ( (FillModeC == FillMode::kLower && row >= col) || + (FillModeC == FillMode::kUpper && row <= col) ) + ) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_rank2k( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Rank2K; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Rank2K { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a738101660f7ebbdd7c7796d46df244f1e3f5f70 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h @@ -0,0 +1,318 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + } + } + } + } + + /* HER2K need two epilogues to handle complex alpha value */ + if ( blas_mode == BlasMode::kHermitian ) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op(alpha * + ScalarType(accum[i][j]) + + beta * c); + } + } + } + + /* Zeoring out accum for second HERK */ + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); + } + } + } + } + + ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? + conj(alpha) : alpha; + ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? + 1 : beta; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType d = (blas_mode == BlasMode::kHermitian) ? + tensor_d.at(coord) : tensor_c.at(coord); + + ScalarType tmp_d = convert_op( + alpha_hermitian * ScalarType(accum[i][j]) + + beta_hermitian * d); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..1aad33fd643b60752bc0845e403cebc43ad7d047 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x A^T (Symmetric) or A x A^H (Hermitian) + // complex conjugation on operandB (a_t) (function of blas3 computation) + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))) : + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(a_t); + + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + ScalarType tmp_d = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void RankKComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h new file mode 100644 index 0000000000000000000000000000000000000000..34f9648f25f8965f6730999b7763220c360683a8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for SYMM update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + FillModeA == FillMode::kLower || + FillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (SideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + + ComputeType compute_a_1(cast_if_scalar(a_1)); + ComputeType compute_b_1(cast_if_scalar(b_1)); + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (SideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? + tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + } + + ComputeType compute_a_2(cast_if_scalar(a_2)); + ComputeType compute_b_2(cast_if_scalar(b_2)); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_symm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Symm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Symm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..79e146f69b784a92ce61a093f410e93a66005cf8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued SYMM update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = BlasMode::kSymmetric, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static BlasMode const kBlasMode = BlasMode_; + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(kSideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + kFillModeA == FillMode::kLower || + kFillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) + { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (kSideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (kSideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + ComputeType compute_a_1 = ComputeType(a_1); + ComputeType compute_b_1 = ComputeType(b_1); + + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { + compute_a_1 = real(compute_a_1); + } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { + compute_b_1 = real(compute_b_1); + } + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (kSideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + if (kBlasMode == BlasMode::kHermitian) + a_2 = conj(a_2); + } else if (kSideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? + tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + if (kBlasMode == BlasMode::kHermitian) + b_2 = conj(b_2); + } + + ComputeType compute_a_2 = ComputeType(a_2); + ComputeType compute_b_2 = ComputeType(b_2); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + ScalarType c = tensor_c.at(coord); + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct SymmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..d6b85ca1baf65ba811b7c8b3a224ca90bbce1680 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -0,0 +1,616 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorGreatestErrorFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double result; + + /// Ctor + TensorGreatestErrorFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + result(0.0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + result = std::max(result, std::abs(double(lhs_) - double(rhs_))); + } + + /// Returns true if equal + operator double() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorMREFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double sum; + uint64_t count; + static constexpr double epsilon = 1e-6; + + /// Ctor + TensorMREFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + sum(0.0), + count(0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + sum += std::abs(double(lhs_) - double(rhs_) / (double(rhs_) + epsilon)); + ++count; + } + + /// Returns true if equal + operator double() const { + return sum / double(count); + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorMSEFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double sum; + uint64_t count; + + /// Ctor + TensorMSEFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + sum(0.0), + count(0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + sum += std::pow((double(lhs_) - double(rhs_)), 2); + ++count; + } + + /// Returns true if equal + operator double() const { + return sum / double(count); + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + bool result; + + /// Ctor + TensorEqualsFunc(): result(true) { } + + /// Ctor + TensorEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), rhs(rhs_), result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (lhs_ != rhs_) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorRelativelyEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + Element epsilon; + Element nonzero_floor; + bool result; + + /// Ctor + TensorRelativelyEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_, + Element epsilon_, + Element nonzero_floor_ + ) : + lhs(lhs_), + rhs(rhs_), + epsilon(epsilon_), + nonzero_floor(nonzero_floor_), + result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the Mean Squared Error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorMSE( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorMSEFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the Mean Relative Error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorMRE( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorMREFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the greatest error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorGreatestError( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorGreatestErrorFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorView const &lhs, + TensorView const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc func(lhs, rhs, epsilon, nonzero_floor); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are NOT equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return true; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return !bool(func); +} + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + return !TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorContainsFunc { + + // + // Data members + // + + TensorView view; + Element value; + bool contains; + Coord location; + + // + // Methods + // + + /// Ctor + TensorContainsFunc(): contains(false) { } + + /// Ctor + TensorContainsFunc( + TensorView const &view_, + Element value_ + ) : + view(view_), value(value_), contains(false) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + if (view.at(coord) == value) { + if (!contains) { + location = coord; + } + contains = true; + } + } + + /// Returns true if equal + operator bool() const { + return contains; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if a value is present in a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorContains( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return bool(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of +/// of the first occurrence. If the value is not contained in the tensor, the second element of the +/// pair is undefined. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +std::pair > TensorFind( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return std::make_pair(bool(func), func.location); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp new file mode 100644 index 0000000000000000000000000000000000000000..27ef969b4ff2b6d8f3a53f3d1a3e5ec3e5203ec3 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename TensorL, + typename TensorR +> +bool TensorEquals( + TensorL lhs, + TensorR rhs) { + + // Extents must be identical + if (cute::size(lhs) != cute::size(rhs)) { + return false; + } + + for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { + if (lhs(idx) != rhs(idx)) { + return false; + } + } + + return true; +} + +/// Returns true if two tensor views are NOT equal. +template < + typename TensorL, + typename TensorR +> +bool TensorNotEquals( + TensorL lhs, + TensorR rhs) { + + return TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h new file mode 100644 index 0000000000000000000000000000000000000000..d2a43b1295c8ab18c7d649c79b0364b6d3e7c48c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h @@ -0,0 +1,256 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to convert between types +template < + typename DstElement, + typename SrcElement +> +struct TrivialConvert { + + TrivialConvert() { } + + DstElement operator()(SrcElement src) const { + return DstElement(src); + } +}; + +/// Helper to conditionally copy between tensor views. +template < + typename DstElement, + typename DstLayout, + typename SrcElement, + typename SrcLayout, + typename F +> +struct TensorCopyIf { + + using DstTensorView = TensorView; + using SrcTensorView = TensorView; + + // + // Data members + // + + DstTensorView dst; + SrcTensorView src; + F convert; + + // + // Methods + // + + TensorCopyIf() { } + + TensorCopyIf( + DstTensorView const &dst_, + SrcTensorView const &src_, + F const &convert_): dst(dst_), src(src_), convert(convert_) {} + + /// Copies based on destination and source bounds + void operator()(Coord const &coord) { + if (dst.contains(coord) && src.contains(coord)) { + dst.at(coord) = convert(src.at(coord)); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + CopyIf copy_if(dst, src, transform); + + TensorForEach(dst.extent(), copy_if); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView src_view(src, dst.extent()); + + CopyIf copy_if(dst, src_view, transform); + + TensorForEach(dst.extent(), copy_if); +} + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorRef dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView dst_view(dst, src.extent()); + + CopyIf copy_if(dst_view, src, transform); + + TensorForEach(src.extent(), copy_if); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorView dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorRef dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..5470df29358799f6d5e6628e8722f0e3dc05485f --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to apply a binary operator in place +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementD, + typename LayoutD, + typename BinaryFunc> +struct TensorFuncBinaryOp { + + // + // Data members + // + + /// View of left-hand-side tensor + TensorView view_d; + TensorRef view_a; + TensorRef view_b; + BinaryFunc func; + + // + // Methods + // + + /// Constructor + TensorFuncBinaryOp() { } + + /// Constructor + TensorFuncBinaryOp( + TensorView const & view_d_, + TensorRef const & view_a_, + TensorRef const & view_b_, + BinaryFunc func = BinaryFunc() + ): + view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } + + /// Equality check + void operator()(Coord const &coord) const { + view_d.at(coord) = func( + ElementD(view_a.at(coord)), + ElementD(view_b.at(coord)) + ); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Adds two tensors and stores in the destination tensor: d = a + b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::plus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Adds a tensor in place: d = d .+ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorAdd(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Subtracts two tensors and stores in the destination tensor: d = a - b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference + ) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::minus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Subtracts two tensors in place: d = d .- a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference + ) { + + TensorSub(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiplies two tensors and stores in the destination tensor: d = a .* b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::multiplies + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Multiplies tensors in place: d = d .* a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorMul(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..645902f7dd7b62bc98a479e4956dfb4b437d46a7 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -0,0 +1,1718 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element value; + + // + // Methods + // + + TensorFillFunc( + TensorView const &view_ = TensorView(), + Element value_ = Element(0) + ): view(view_), value(value_) { } + + void operator()(Coord const & coord) const { + view.at(coord) = value; + } +}; + +/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method +struct BoxMullerFunc { + + BoxMullerFunc() {} + + void operator()( + double* rnd, ///< Size-2 vector to be filled with random values + double mean = 0, ///< Mean of the Gaussian distribution + double stddev = 1, ///< Standard deviation of the Gaussian distribution + double pi = std::acos(-1)) const { + + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); + rnd[0] = mean + stddev * rnd[0]; + rnd[1] = mean + stddev * rnd[1]; + } +}; +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView dst, ///< destination tensor + Element val = Element(0)) { ///< value to uniformly fill it with + + detail::TensorFillFunc func(dst, val); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorViewPlanarComplex dst, ///< destination tensor + cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with + + TensorFill(dst.view_real(), val.real()); + TensorFill(dst.view_imag(), val.imag()); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + } + else { + result = static_cast(0); + } + + // Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros + if (exclude_zero && result == Element(0)) { + if (rnd > 0) { + rnd += 1; + } else { + rnd -= 1; + } + result = Element(rnd); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + double rnd[2]; + detail::BoxMullerFunc func; + func(rnd, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale))); + rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale))); + reals[0] = from_real(rnd[0] / double(1 << int_scale)); + reals[1] = from_real(rnd[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd[0]); + reals[1] = from_real(rnd[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0.0) && + reals[1] == from_real(0.0)) { + + if (rnd[0] > 0.0) { + rnd[0] += 1.0; + } else { + rnd[0] -= 1.0; + } + reals[0] = from_real(rnd[0]); + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + double rnd1[2]; + double rnd2[2]; + detail::BoxMullerFunc func; + func(rnd1, mean, stddev, pi); + func(rnd2, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale))); + + reals[0] = from_real(rnd1[0] / double(1 << int_scale)); + reals[1] = from_real(rnd1[1] / double(1 << int_scale)); + reals[2] = from_real(rnd2[0] / double(1 << int_scale)); + reals[3] = from_real(rnd2[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd1[0]); + reals[1] = from_real(rnd1[1]); + reals[2] = from_real(rnd2[0]); + reals[3] = from_real(rnd2[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + reals[2] = from_real(0); + reals[3] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0) && + reals[1] == from_real(0) && + reals[2] == from_real(0) && + reals[3] == from_real(0)) { + + if (rnd1[0] > 0.0) { + rnd1[0] += 1.0; + } else { + rnd1[0] -= 1.0; + } + reals[0] = from_real(rnd1[0]); + } + + return Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + view.at(coord) = func(); + } +}; + +/// Computes a random Gaussian distribution for a rank-2 tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillSymmetricGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz, exclude_zero); + + detail::TensorFillGaussianFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + detail::TensorFillSymmetricGaussianFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) + { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + Element operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + if (exclude_zero && result == Element(0)) { + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + complex operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + + if (exclude_zero && + i == 0 && + reals[0] == from_real(0.0)) { + + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + reals[0] = from_real(Real(rnd)); + } + + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_), + bernoulli_rnd{static_cast(seed_)}, + bernoulli_dist(pnan_) + { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + + view.at(coord) = func(); + } +}; + +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillSymmetricRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +/// Computes a random Uniform distribution and pads diagonal with zeros +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPadDiagonalRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + int alignment; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillPadDiagonalRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, + int alignment_ = 1 + ): + view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + (fill_mode == cutlass::FillMode::kLower) && + (coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= alignment)) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + (coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= alignment)) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan, exclude_zero); + + detail::TensorFillRandomUniformFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero); +} + + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView, Layout> dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc> random_func(seed, max, min, bits); + + detail::TensorFillRandomUniformFunc, Layout> func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillSymmetricRandomUniformFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPadDiagonalRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int alignment = 1 +) { + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillPadDiagonalRandomUniformFunc func( + dst, + random_func, + fill_mode, + alignment + ); + + TensorForEach( + dst.extent(), + func + ); +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element ///< Element type +> +void BlockFill( + Element *ptr, + size_t capacity, + Element val + ) { + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = val; + } +} + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0) { ///< Percentage of NaN elements. + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + // + // Methods + // + + TensorFillDiagonalFunc( + TensorView const &view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + view.at(coord) = (is_diag ? diag : other); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView dst, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0)) { ///< value to write off the diagonal + + detail::TensorFillDiagonalFunc func( + dst, + diag, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView dst) { ///< destination tensor + + TensorFillDiagonal(dst, Element(1), Element(0)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView dst, ///< destination tensor + Element val = Element(1)) { + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = val; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element other; + + // + // Methods + // + + TensorUpdateOffDiagonalFunc( + TensorView const &view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + view.at(coord) = other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView dst, ///< destination tensor + Element other = Element(1)) { + + detail::TensorUpdateOffDiagonalFunc func( + dst, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + // + // Methods + // + + TensorFillLinearFunc() { } + + /// Constructs functor + TensorFillLinearFunc( + TensorView const &view_, + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { } + + /// Updates the tensor + void operator()(Coord const & coord) const { + + Element sum(s); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + sum += Element(coord[i]) * v[i]; + } + + view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView dst, ///< destination tensor + Array const & v, + Element s = Element(0)) { + + detail::TensorFillLinearFunc func( + dst, + v, + s + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSequential( + TensorView dst, ///< destination tensor + Element s = Element(0)) { + + Array stride; + + stride[0] = Element(1); + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); + } + + TensorFillLinear(dst, stride, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + bool exclude_zero = false ///< If true, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. +) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz, + exclude_zero); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan, + exclude_zero); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = s; + + s = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist) { + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomSparseMetaFunc { + + uint64_t seed; + int range; + int MetaSizeInBits; + + // + // Methods + // + + RandomSparseMetaFunc( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), MetaSizeInBits(MetaSizeInBits_) { + std::srand((unsigned)seed); + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + + /// Compute random value and update RNG state + Element operator()() const { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + int rnd = std::rand() % range; + Element meta = MetaArray[rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random sparse meta +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomSparseMetaFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillRandomSparseMetaFunc( + TensorView view_ = TensorView(), + RandomSparseMetaFunc func_ = RandomSparseMetaFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + + view.at(coord) = func(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4 bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + detail::TensorFillRandomSparseMetaFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a ell block index matrix with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomEllIdx( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int rows, int ell_cols, int cols) { ///< dimension of the matrix + + std::srand((unsigned)seed); + + for (int i = 0; i < rows; ++i) { + int col_idx = std::rand() % cols; + + for (int j = 0; j < ell_cols; ++j) { + dst.at({i, j}) = col_idx; + + if (col_idx != -1) { + if (col_idx == (cols - 1)) { + col_idx = -1; + } else { + col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView dst, ///< destination tensor + Element const *ptr) { ///< dense buffer of elements + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = ReferenceFactory::get(ptr, i); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView src) { ///< source tensor + + typename Layout::Index extent = src.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + ReferenceFactory::get(ptr, i) = src.at(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b3df239a1b9d69fc12e7ec4be2de6f87b3a0e3c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Uniform and procedural tensor fills +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a scalar element +template +void TensorFill(Tensor dst, typename Tensor::value_type element) { + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = element; + } +} + +/// Fills a tensor with the contents of its layout +template +void TensorFillSequential(Tensor dst) { + + auto layout = dst.layout(); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = layout(idx); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random uniform values +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Element operator()() const { + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template ///< Tensor object +void TensorFillRandomUniform( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random Gaussian +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Tensor +> +void TensorFillRandomGaussian( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..bcb1af995805e3fbcbdbf398ce7191ea2f0dbe8d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - RankRemaining - 1; + + /// Constructor for general rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + TensorForEachHelper(func, extent, coord); + } + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - 1; + + /// Constructor for fastest changing rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEach(Coord extent, Func & func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor and calls a C++ lambda +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEachLambda(Coord extent, Func func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params()) { + + Func func(params); + + for (size_t index = 0; index < capacity; ++index) { + ptr[index] = func(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..d44dda1f5472f13b7212f7e2e4020e254ff92f88 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" + +// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. + +#include "cutlass/util/reference/host/tensor_reduce.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..887c568059a90f749fc0ac75dd211ce77085a5a9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/reference/detail/linear_to_coordinate.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) { + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + Element x = view.at(coord); + identity = reduce(identity, transform(x)); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Tensor extents must match."); + } + + for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) { + + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + Element a = view_A.at(coord); + Element b = view_B.at(coord); + identity = reduce(identity, transform(a, b)); + } + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea711466df86703aae1702605a928754c9f4e944 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tensor reductions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Tensor, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + Tensor view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < cute::size(view); ++idx) { + identity = reduce(identity, transform(view(idx))); + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename TensorA, + typename TensorB, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorA view_A, + TensorB view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (cute::size(view_A) != cute::size(view_B)) { + throw std::runtime_error("Tensor sizes must match."); + } + + for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { + identity = reduce(identity, transform(view_A(idx), view_B(idx))); + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSum( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSumSq( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Tensor, + typename ComputeType = double +> +ComputeType TensorNorm( + Tensor view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h new file mode 100644 index 0000000000000000000000000000000000000000..09b1aff9c0ea9922af46c928a3dd61595be2e4cd --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Trmm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Trmm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..e8db2a4deaf8608882595d68e611f8ae79e134e8 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + // Conjugate, and hence hermitian, is only allowed for the triangular matrix + if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct TrmmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h new file mode 100644 index 0000000000000000000000000000000000000000..0ce1d8a65fdd66ace69f91525b678dd6ad132d24 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +* +**************************************************************************************************/ +#pragma once + +#include "cutlass/core_io.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteLeastSignificantRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + out << ScalarIO(view.at(coord)); + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by "\n" + if (idx) { + out << ",\n"; + } + TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << ",\n\n"; + } + TensorView_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by ";\n" + if (idx) { + out << ";\n"; + } + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << "\n"; + } + TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..5dfbfe274dec368cfac291a1c78ece6ffb203c72 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Type traits for common CUDA types +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "cutlass/complex.h" + +namespace cutlass { +struct half_t; + +template +struct TypeTraits { + typedef T host_type; + typedef T device_type; + static inline T remove_negative_zero(T x) { return x; } + static inline T to_print(T x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int8_t host_type; + typedef int8_t device_type; + typedef int8_t integer_type; + typedef uint8_t unsigned_type; + static inline int8_t remove_negative_zero(int8_t x) { return x; } + static inline int to_print(int8_t x) { return (int)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint8_t host_type; + typedef uint8_t device_type; + typedef uint8_t integer_type; + typedef uint8_t unsigned_type; + static inline uint8_t remove_negative_zero(uint8_t x) { return x; } + static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef int host_type; + typedef int device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline int32_t remove_negative_zero(int32_t x) { return x; } + static inline int to_print(int x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef unsigned host_type; + typedef unsigned device_type; + typedef uint32_t integer_type; + typedef uint32_t unsigned_type; + static inline uint32_t remove_negative_zero(uint32_t x) { return x; } + static inline uint32_t to_print(uint32_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int64_t host_type; + typedef int64_t device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline int64_t remove_negative_zero(int64_t x) { return x; } + static inline int64_t to_print(int64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint64_t host_type; + typedef uint64_t device_type; + typedef uint64_t integer_type; + typedef uint64_t unsigned_type; + static inline uint64_t remove_negative_zero(uint64_t x) { return x; } + static inline uint64_t to_print(uint64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_16F; + typedef half_t host_type; + typedef half_t device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline half_t remove_negative_zero(half_t x) { + return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); + } + static inline half_t to_print(half_t x) { return x; } + static inline device_type to_device(half_t x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32F; + typedef float host_type; + typedef float device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; } + static inline float to_print(float x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_64F; + typedef double host_type; + typedef double device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } + static inline double to_print(double x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Complex types +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0_hf ? 0_hf : real(x), + imag(x) == -0_hf ? 0_hf : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + + static cudaDataType_t const cublas_type = CUDA_C_32F; + typedef complex host_type; + typedef complex device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.f ? 0.f : real(x), + imag(x) == -0.f ? 0.f : imag(x) + ); + } + + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_64F; + typedef complex host_type; + typedef complex device_type; + struct integer_type { int64_t real, imag; }; + struct unsigned_type { uint64_t real, imag; }; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.0 ? 0.0 : real(x), + imag(x) == -0.0 ? 0.0 : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py new file mode 100644 index 0000000000000000000000000000000000000000..6541ce1b26722ff1f0dba0b4c034067a62f9b96d --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py @@ -0,0 +1,356 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +""" +Given a set of test files to be included in a CMake target, this script extracts +the TEST definitions from each file, writes them into new files, and prints the names +of the new files so that they can be processed as part of a new CMake target. + +For example, given a set of --src_files test_a.cu test_b.cu containing 3 and 2 TEST +definitions, respectively, this script would produce: + test_a_000.cu + test_a_001.cu + test_a_002.cu + test_b_000.cu + test_b_001.cu + +The splitting follows a fairly rudimentary algorithm that does not support all valid C++ programs. +We walk through a given input test file line by line. Any lines that are not within a TEST definition is added to a running +"filler" text. When a TEST definition is encountered, the current filler text becomes the prefix +for that test. All subsequent lines are considered to be part of the TEST definition until the +number of starting function braces ('{') match the number of closing function braces ('}'). When +these counts are equal, the TEST definition is considered to be completed. At this point, we return +to adding lines to the "filler" text until a new TEST definition is encountered. Any "filler" text +following a TEST definition is added to the suffix of that TEST definition (this is useful for finishing +off #if statements, as is common in unit tests.). + +A state machine illustrating this algorithm at a high level is provided in the source below. + +Example: Suppose an input test `test.cu` has the following source: + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +The contents of the two resulting test files will be: + $ cat test_000.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + $ cat test_001.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +Notice that each of test_000.cu and test_001.cu contain comments that appear outside +the TEST definitions not included in each file. This is by design, as these +would be considered "filler" text. + +As expected, some cases can't be handled. Below is a non-exhaustive list: + 1. New TEST following the closing '}' of a TEST case on the same line: + TEST(x, y) { + // Do stuff + } TEST(a, b) { + + In this case, "TEST(a, b) {" will be ignored + + 2. Preprocessor macros that occur midway through a test case and extend + beyond the conclusion of a testcase + + Example: + TEST(a, b) { + // Do stuff + #if X + // Do more stuff + } + #else + // Do other stuff + } + #endif +""" + + +import argparse +import enum +import os + + +parser = argparse.ArgumentParser() +parser.add_argument("cmake_target", type=str, + help="Name of the CMake target being generated.") +parser.add_argument("src_dir", type=str, + help="Path to the directory containing test files.") +parser.add_argument("--src_files", nargs='+', + help="Files containing TEST instances to split.") +parser.add_argument("--max_tests_per_file", type=int, default=1, + help="Maximum number of TEST instances per file.") +parser.add_argument("--dst_dir", type=str, + help="Path to the directory to which to write new test files. If not set, uses src_dir.") +args = parser.parse_args() + + +if args.dst_dir == None: + args.dst_dir = args.src_dir + + +class Testcase: + """ + Lightweight tracker of test-case processing status + """ + def __init__(self, prefix_text): + # Any text that preceded the TEST definition that was + # not part of another TEST definition + self.prefix = prefix_text + + # Any text within the TEST definition + self.test = "" + + # Any text that follows the completion of the TEST definition + # and is not included in other TEST definitions + self.suffix = "" + + # Whether the test's definition has concluded + self.completed = False + + # Current balance of opening and closing curly brackets in + # the TEST definition. '{' increments the count and '}' decrements it. + # A value of 0 (when self.completed == False) indicates that the test + # has completed. + self.curly_bracket_balance = 0 + + +class ParseState(enum.Enum): + """ + State machine for processing. + Transitions occur on each line encountered in the soruce file + + + Line does not contain 'TEST(' + +----+ + | | + | v 'TEST(' + +--------+ encountered +--------------------------+ + ------>| Filler | -----------------------> | TestDeclaredWaitingStart | + +--------+ +--------------------------+ + ^ | + Number of '{' | | First '{' encountered + equals number of | +--------+ | + '}' encountered +-----------| InTest | <------------------+ + +--------+ + | ^ + | | + +----+ + Number of '{' encountered + exceeds number of '}' encountered + """ + + + # Any text that is not part of a TEST case + Filler = 0 + + # Processing text within the first { of the TEST case + # and before the en of the final } of the TEST case + InTest = 1 + + # Processing text from the start of the TEST definition + # but before the first {. This could occur if the opening { + # occurs on a separate line than the TEST definition. + TestDeclaredWaitingStart = 2 + + +cmake_src_list = [] +for filename in args.src_files: + if '.' not in filename: + # Add any non-filename arguments to the command list by default + cmake_src_list.append(filename) + continue + + if '/' in filename: + raise Exception( + f"Source files passed to {__file__} must be within the same directory " + "as the CMakeLists defining the target using the files. " + f"Provided path {filename} is in a different directory.") + + full_filename = os.path.join(args.src_dir, filename) + with open(full_filename, 'r') as infile: + lines = infile.readlines() + + # Find the number of instances of "TEST(" + ntest = sum([1 for line in lines if "TEST(" in line]) + + if ntest <= args.max_tests_per_file: + # File contains fewer than max_tests_per_file TEST instances. It does + # not need to be split + cmake_src_list.append(filename) + continue + + # Current state of the parsing state machine. We start with filler text + state = ParseState.Filler + + # List of individual TESTs found + tests = [] + + # Ongoing text that is not included in a TEST definition. This will serve + # as the prefix for any yet-to-be encountered TEST definitions. + filler_text = "" + + def add_filler_text(text): + global filler_text + # Add new text to the ongoing filler text and to the suffixes of + # any completed tests + filler_text += text + for i in range(len(tests)): + if tests[i].completed: + tests[i].suffix += text + + for line in lines: + if state == ParseState.Filler: + # We are not currently within a TEST definition. + + if 'TEST(' in line: + # We have encountered a new TEST( case. Any text preceding this + # must be added to the filler text (e.g., if we have a line of the form: + # "static constexpr int Val = 4; TEST(blah) {" + # then "static constexpr int Val = 4;" needs to be included in filler + # text, as it could be used by subsequent tests.) + splits = line.split('TEST') + + # There should not be more than one TEST definition on a given line + assert len(splits) <= 2 + + if len(splits) > 1: + if not splits[0].isspace(): + # Only add text to filler if there are non-whitespace charcters + # preceding the TEST definition in the line + filler_text += splits[0] + + # The new line is just the TEST-related line + line = 'TEST' + splits[-1] + + # Add tests and transtion to TestDeclaredWaitingStart state. + # Do not add the line to the test text of the new test case; this + # will be done in either the TestDeclaredWaitingStart state processing + # below or in the InTest state processing below. + tests.append(Testcase(filler_text)) + state = ParseState.TestDeclaredWaitingStart + else: + # Any remaining filler text is added to the running filler_text + # which will be used as the prefix for any new tests, and to the + # suffix of any completed tests + add_filler_text(line) + + if state == ParseState.TestDeclaredWaitingStart: + # We have seen a TEST definition but have not yet seen its opening {. + + if '{' in line: + # The first curly bracket for the TEST definition has been found. + # Advance to state InTests. Do not add the line to the test's text + # or change the curly-brace balance of the test; these will be done + # when processing the state == ParseState.InTest condition below. + state = ParseState.InTest + else: + tests[-1].test += line + + if state == ParseState.InTest: + # We are currently within a TEST definition. + # Process lines character-by-character looking for opening and closing + # braces. If we reach parity between opening and closing braces, the + # test is considered done. + filler_text_to_add = "" + for char in line: + if not tests[-1].completed: + tests[-1].test += char + if char == '{': + tests[-1].curly_bracket_balance += 1 + elif char == '}': + tests[-1].curly_bracket_balance -= 1 + if tests[-1].curly_bracket_balance == 0: + tests[-1].completed = True + else: + filler_text_to_add += char + + if filler_text_to_add != "" and (not filler_text_to_add.isspace() or '\n' in filler_text_to_add): + add_filler_text('\n' + filler_text_to_add) + + if tests[-1].completed: + state = ParseState.Filler + + # Write out the new files for tests + filename_prefix, filename_suffix = filename.split('.') + for i, test in enumerate(tests): + assert test.completed + new_filename = filename_prefix + '_' + str(i).zfill(3) + '.' + filename_suffix + full_new_filename = os.path.join(args.dst_dir, new_filename) + + # Replace any '\' with '/'. CMake doesn't like '\'. + full_new_filename = full_new_filename.replace('\\', '/') + + with open(full_new_filename, 'w') as outfile: + outfile.write(test.prefix + test.test + test.suffix) + cmake_src_list.append(full_new_filename) + + +for cmake_file in cmake_src_list: + print(cmake_file) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/legacy/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/legacy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cce39ec7be8e80c6c99a4f9f10cba12c63f059ad --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/legacy/__init__.py @@ -0,0 +1,5 @@ +# All kernels may be deprecated in the future (or rewrite in TileLang) +from .m_grouped_gemm import * +from .a_fused_m_grouped_gemm import * +from .a_fused_k_grouped_gemm import * +from .b_fused_k_grouped_gemm import * diff --git a/build/torch212-cxx11-cu132-x86_64-linux/legacy/a_fused_k_grouped_gemm.py b/build/torch212-cxx11-cu132-x86_64-linux/legacy/a_fused_k_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..7b42f152ac183ecbdf72aae4e121295af6504e11 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/legacy/a_fused_k_grouped_gemm.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k + tl.arange(0, BLOCK_SIZE_K) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + rows[None, :] * M + + b_ptrs = b_ptr + k_range[:, None].to(tl.int64) * N + n_range[None, :] + a = tl.load(a_ptrs, mask=(rows >= 0)[None, :] & m_mask, other=0) + b = tl.load(b_ptrs, mask=n_mask, other=0) + acc = tl.dot(a, b, acc) + + # Write back + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K_, M = a.shape + K, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/legacy/a_fused_m_grouped_gemm.py b/build/torch212-cxx11-cu132-x86_64-linux/legacy/a_fused_m_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..41b35d539796c30bb7589b9f5b5f98bb5a4d468e --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/legacy/a_fused_m_grouped_gemm.py @@ -0,0 +1,92 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, m_row_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # b block + rows = tl.load(m_row_indices_ptr + m_range).to(tl.int64) + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + k_mask = k_range < K + a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] + b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + a = tl.load(a_ptrs, mask=(rows >= 0)[:, None] & k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + d = acc.to(d_ptr.dtype.element_ty) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, d, mask=n_mask) + + +def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + m_indices, m_row_indices = mappings + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous() + assert m_indices.is_contiguous() and m_row_indices.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and d.size(0) == m_indices.numel() and d.size(1) == r1 + assert m_indices.numel() == m_row_indices.numel() + assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0 + + if d.size(0) == 0: + return d + + M_, K = a.shape + B, K, N = r0, r2, r1 + M = m_indices.numel() + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) + a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices, + M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, mappings) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/legacy/b_fused_k_grouped_gemm.py b/build/torch212-cxx11-cu132-x86_64-linux/legacy/b_fused_k_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..7df8741fa9b8d00498b5de61b609ef0980a3e873 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/legacy/b_fused_k_grouped_gemm.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M + b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + b = tl.load(b_ptrs, mask=(rows >= 0)[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K, M = a.shape + K_, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/legacy/m_grouped_gemm.py b/build/torch212-cxx11-cu132-x86_64-linux/legacy/m_grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..e685a9ab01b44ead9d16e4d1696d08716e12e47c --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/legacy/m_grouped_gemm.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + # Empty tokens + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # Compute + a_ptrs = a_ptr + m_range[:, None].to(tl.int64) * K + tl.arange(0, BLOCK_SIZE_K)[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + b_ptrs = b_ptr + batch_id * K * N + \ + tl.arange(0, BLOCK_SIZE_K)[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \ + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + for k in range(0, K, BLOCK_SIZE_K): + k_mask = (k + tl.arange(0, BLOCK_SIZE_K)) < K + a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, accumulator.to(d_ptr.dtype.element_ty), mask=n_mask) + + +def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous or b.mT.is_contiguous()) + assert m_indices.is_contiguous() and d.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1) + assert m_indices.numel() == a.size(0) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + M, K = a.shape + B, N, K_ = r0, r1, r2 + + # For Triton 2.0, persistent kernel will lead to errors + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + m_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, m_indices) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/legacy/tune_options.py b/build/torch212-cxx11-cu132-x86_64-linux/legacy/tune_options.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6a7f77c05ccea324a0e99e12e1506cdea0a086 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/legacy/tune_options.py @@ -0,0 +1,28 @@ +from triton import Config +from .._C import get_mk_alignment_for_contiguous_layout + + +def get_config_smem_size(config: Config, elem_bytes: int = 2): + # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels + return (config.kwargs['BLOCK_SIZE_M'] + config.kwargs['BLOCK_SIZE_N']) * config.kwargs['BLOCK_SIZE_K'] * elem_bytes * config.num_stages + + +_gemm_configs = [ + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), +] + +# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere +_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) + +get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/mega/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/mega/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..670b409dada5ef46b62324fb458a10b585ca0a01 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/mega/__init__.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import torch +from typing import Tuple, Optional +from ..utils.math import align + +# noinspection PyBroadException +try: + # noinspection PyProtectedMember + import torch.distributed._symmetric_memory as symm_mem + import torch.distributed as dist +except Exception as exception: + print(f'Failed to load mega kernels, please check your PyTorch version: {exception}') + +from .. import _C + + +class SymmBuffer: + def __init__(self, group: dist.ProcessGroup, + # MoE arguments + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu'): + self.group = group + self.num_experts = num_experts + self.num_max_tokens_per_rank = num_max_tokens_per_rank + self.num_topk = num_topk + self.hidden = hidden + self.intermediate_hidden = intermediate_hidden + + # Allocate a symmetric buffer + num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe( + group.size(), num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda') + self.handle = symm_mem.rendezvous(self.buffer, group=group) + self.buffer.zero_() + self.group.barrier() + torch.cuda.synchronize() + + # Create input buffer views + (self.x, self.x_sf, + self.topk_idx, self.topk_weights, + self.l1_acts, self.l1_acts_sf, + self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer) + + def destroy(self): + self.handle = None + self.buffer = None + self.group = None + self.x = None + self.x_sf = None + + +def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup, + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu') -> SymmBuffer: + # Token count must be aligned to block sizes + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) + + return SymmBuffer( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + + +def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up] + def interleave(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return interleave(l1_weights[0]), interleave(l1_weights[1]) + + +def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: + num_groups, mn, packed_sf_k = sf.shape + assert sf.dtype == torch.int and mn % 128 == 0 + result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k) + .transpose(2, 3) + .reshape(num_groups, mn, packed_sf_k)) + return torch.empty_like(sf).copy_(result) + + +def transform_weights_for_mega_moe( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + # L1: interleave gate/up, then transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1])) + # L2: only transpose SF for UTCCP + l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights + + +def fp8_fp4_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (1, 1, 32), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + _C.fp8_fp4_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/build/torch212-cxx11-cu132-x86_64-linux/metadata.json b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..843aad1b6073c0237b3a8e4e8a99029dadeceb52 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/metadata.json @@ -0,0 +1,13 @@ +{ + "name": "deep-gemm", + "id": "_deep_gemm_cuda_388adb9", + "version": 2, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch212-cxx11-cu132-x86_64-linux/testing/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13a9d78dea58a6492183f9ddc50f1510a679cbe6 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/testing/__init__.py @@ -0,0 +1,4 @@ +from . import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/build/torch212-cxx11-cu132-x86_64-linux/testing/bench.py b/build/torch212-cxx11-cu132-x86_64-linux/testing/bench.py new file mode 100644 index 0000000000000000000000000000000000000000..552b9aa18a037a14d0869fac3527e34bac6d7760 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/testing/bench.py @@ -0,0 +1,146 @@ +import os +import sys +import torch +from typing import Callable, Optional + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False, + barrier: Optional[Callable] = None): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + if barrier is not None: + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + # noinspection PyProtectedMember + torch.cuda._sleep(int(2e7)) # ~10ms + barrier() + fn() + torch.cuda.synchronize() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/build/torch212-cxx11-cu132-x86_64-linux/testing/numeric.py b/build/torch212-cxx11-cu132-x86_64-linux/testing/numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..a42c4318db47593c47a4ea89fbdbcb1ffb5cd30e --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/build/torch212-cxx11-cu132-x86_64-linux/testing/utils.py b/build/torch212-cxx11-cu132-x86_64-linux/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d202d4192ed385f986ac5cc216acc69378d8ea9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/build/torch212-cxx11-cu132-x86_64-linux/utils/__init__.py b/build/torch212-cxx11-cu132-x86_64-linux/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0dc6f783bcd24c7be4d6afaec9fe5d12a6847d0 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/utils/__init__.py @@ -0,0 +1,4 @@ +from . import math, layout +from .layout import * +from .math import * +from .dist import init_dist, uneven_all_gather diff --git a/build/torch212-cxx11-cu132-x86_64-linux/utils/dist.py b/build/torch212-cxx11-cu132-x86_64-linux/utils/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..426c39676f2f4374fa6e6c646cbcc0ca8b5a7b88 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/utils/dist.py @@ -0,0 +1,74 @@ +import inspect +import os +import torch +import torch.distributed as dist +from typing import Tuple + +_local_rank = None + + +def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]: + # NOTES: you may rewrite this function with your own cluster settings + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + + # Set local rank + global _local_rank + _local_rank = local_rank + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_device('cuda') + torch.cuda.set_device(local_rank) + + return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + +def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor: + world_size = dist.get_world_size(group) + + # Exchange sizes + local_dim_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long) + all_dim_sizes = [torch.zeros_like(local_dim_size) for _ in range(world_size)] + dist.all_gather(all_dim_sizes, local_dim_size, group=group) + all_dim_sizes = [s.item() for s in all_dim_sizes] + max_dim_size = max(all_dim_sizes) + + # Pad + if tensor.shape[dim] < max_dim_size: + pad_shape = list(tensor.shape) + pad_shape[dim] = max_dim_size - tensor.shape[dim] + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor_padded = torch.cat([tensor, padding], dim=dim) + else: + tensor_padded = tensor.contiguous() + + # All-gather + gathered = [torch.zeros_like(tensor_padded) for _ in range(world_size)] + dist.all_gather(gathered, tensor_padded, group=group) + + # Remove padding + trimmed = [ + torch.narrow(gathered[i], dim, 0, all_dim_sizes[i]) + for i in range(world_size) + ] + return torch.cat(trimmed, dim=dim) + + +def dist_print(s: str = '', once_in_node: bool = False) -> None: + global _local_rank + assert _local_rank is not None + if not once_in_node or _local_rank == 0: + print(s, flush=True) + dist.barrier() diff --git a/build/torch212-cxx11-cu132-x86_64-linux/utils/layout.py b/build/torch212-cxx11-cu132-x86_64-linux/utils/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..6512c5ab7aee2bb07ca8324b7c6e49c420bd9df9 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/utils/layout.py @@ -0,0 +1,21 @@ +try: + from .._C import ( + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor + ) +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Valid for all CUDA versions +from .._C import ( + set_mk_alignment_for_contiguous_layout, + get_mk_alignment_for_contiguous_layout, + get_theoretical_mk_alignment_for_contiguous_layout, +) + +# Some alias +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/build/torch212-cxx11-cu132-x86_64-linux/utils/math.py b/build/torch212-cxx11-cu132-x86_64-linux/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..f1582ed560344e18980054bf502083fd641c1437 --- /dev/null +++ b/build/torch212-cxx11-cu132-x86_64-linux/utils/math.py @@ -0,0 +1,143 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + bits = x.abs().float().view(torch.int) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int() + return (exp.clamp(1, 254) << 23).view(torch.float) + + +def pack_ue8m0_to_int(x: torch.Tensor): + assert x.dtype == torch.float and x.size(-1) % 4 == 0 + assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all() + return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, padded_n // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous() + return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code.view(torch.int8) + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + m, n = x.shape + assert n % 2 == 0 + assert not use_packed_ue8m0 or use_ue8m0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8 + return packed[:, :n // 2].contiguous(), pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.int8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.int8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() + + +def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float) + sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int) + value = fp4_values[value_idx] + return torch.where(sign & (value_idx != 0), -value, value) + + +def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor: + return (packed_sf.view(torch.uint8).to(torch.int) << 23).view(torch.float) + + +def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> torch.Tensor: + m, n2 = packed.shape + n = n2 * 2 + if use_packed_ue8m0: + sf = unpack_ue8m0_from_int(sf) + unpacked = torch.zeros((m, n), dtype=torch.int8, device=packed.device) + unpacked[:, ::2] = packed & 0x0F + unpacked[:, 1::2] = (packed >> 4) & 0x0F + x_dequantized = _dequantize_from_fp4_e2m1(unpacked) + group_idx = torch.arange(n, device=packed.device) // gran_k + x_restored = x_dequantized * sf[:, group_idx] + return x_restored \ No newline at end of file