kernels-bot commited on
Commit
abc8e91
·
verified ·
1 Parent(s): ba53917

Uploaded using `kernel-builder` (batch 32/32).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py +997 -0
  2. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py +725 -0
  3. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py +269 -0
  4. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py +431 -0
  5. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py +184 -0
  6. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py +65 -0
  7. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py +41 -0
  8. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py +262 -0
  9. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py +362 -0
  10. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py +41 -0
  11. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py +196 -0
  12. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py +63 -0
  13. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py +621 -0
  14. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py +482 -0
  15. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py +250 -0
  16. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py +868 -0
  17. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py +1613 -0
  18. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py +0 -0
  19. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py +415 -0
  20. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py +175 -0
  21. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py +1531 -0
  22. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py +868 -0
  23. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py +438 -0
  24. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py +427 -0
  25. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py +342 -0
  26. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py +661 -0
  27. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py +212 -0
  28. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py +753 -0
  29. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py +440 -0
  30. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py +447 -0
  31. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py +132 -0
  32. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py +36 -0
  33. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py +225 -0
  34. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py +367 -0
  35. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py +129 -0
  36. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py +42 -0
  37. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py +74 -0
  38. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py +46 -0
  39. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py +46 -0
  40. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py +661 -0
  41. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py +146 -0
  42. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py +428 -0
  43. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py +44 -0
  44. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py +309 -0
  45. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py +198 -0
  46. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +173 -0
  47. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +142 -0
  48. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +319 -0
  49. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +180 -0
  50. build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py +44 -0
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py ADDED
@@ -0,0 +1,997 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Ease-of-use interface for constructing, compiling, and running CONVs
35
+
36
+ The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
37
+ CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
38
+ Under the hood, the interface will select sensible default parameters for the many template
39
+ parameters for CUTLASS CONVs.
40
+
41
+ Note: optimal performance is not to be expected from this interface. To achieve optimal
42
+ performance, one should specify and tune each configuration parameter.
43
+
44
+ The simplest example of using this interface is the following:
45
+
46
+ .. highlight:: python
47
+ .. code-block:: python
48
+
49
+ # A, B, C, and D are torch/numpy/cupy tensor objects
50
+ plan = cutlass_cppgen.op.Conv(A, B, C, D)
51
+ plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
52
+
53
+ One can also use the interface by specifying data types of operands at construction
54
+ and using different tensor objects with these data types at runtime:
55
+
56
+ .. highlight:: python
57
+ .. code-block:: python
58
+
59
+ # The following is shorthand for:
60
+ # cutlass_cppgen.op.Conv2d(kind="fprop",
61
+ # element_A=torch.float32, element_B=torch.float32,
62
+ # element_C=torch.float32, element_D=torch.float32,
63
+ # element_accumulator=torch.float32)
64
+ plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
65
+
66
+ A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
67
+ B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
68
+ C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
69
+ D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda')
70
+ plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
71
+
72
+ A = torch.rand((32, 128), dtype=torch.float32, device='cuda')
73
+ B = torch.rand((128, 256), dtype=torch.float32, device='cuda')
74
+ C = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
75
+ D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda')
76
+ plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
77
+
78
+ The interface additionally enables one to decouple the compilation of the underlying CUTLASS
79
+ kernel from its execution:
80
+
81
+ .. highlight:: python
82
+ .. code-block:: python
83
+
84
+ plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
85
+
86
+ # Do other work...
87
+
88
+ plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
89
+
90
+ # Do other work...
91
+
92
+ plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
93
+
94
+ Elementwise activation functions are easily fused to the GEMM via the interface:
95
+
96
+ .. highlight:: python
97
+ .. code-block:: python
98
+
99
+ plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
100
+ plan.activation = cutlass_cppgen.epilogue.relu
101
+
102
+ Operations can also be run asynchronously:
103
+
104
+ .. highlight:: python
105
+ .. code-block:: python
106
+
107
+ plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
108
+ args = plan.run()
109
+
110
+ # Do other work...
111
+
112
+ args.sync()
113
+ """
114
+
115
+ from __future__ import annotations
116
+ from typing import Optional
117
+ from cutlass_cppgen.utils.lazy_import import lazy_import
118
+ cuda = lazy_import("cuda.cuda")
119
+ cudart = lazy_import("cuda.cudart")
120
+ from cutlass_library import (
121
+ ConvKind,
122
+ ConvMode,
123
+ DataTypeSize,
124
+ IteratorAlgorithm,
125
+ OperationKind,
126
+ SplitKMode,
127
+ StrideSupport,
128
+ )
129
+
130
+ import cutlass_cppgen
131
+ from cutlass_cppgen import epilogue
132
+ from cutlass_cppgen.backend import compiler
133
+ from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
134
+ from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
135
+ from cutlass_cppgen.backend.library import TensorDescription, TileDescription
136
+ from cutlass_cppgen.op.op import OperationBase
137
+ from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
138
+ from cutlass_cppgen.utils import check, datatypes
139
+
140
+
141
+ class Conv2d(OperationBase):
142
+ """
143
+ Constructs a ``Conv2d`` object.
144
+
145
+ The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C,
146
+ along with the data type of output D and that used for accumulation, are bound to the ``Conv``
147
+ object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
148
+
149
+ The constructor has optional parameters for flexibly setting these parameters. The following
150
+ constructors are equivalent:
151
+
152
+ .. highlight:: python
153
+ .. code-block:: python
154
+
155
+ # Use F32 for A, B, C, D, and accumulation in fprop
156
+
157
+ # Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
158
+ Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
159
+
160
+ # Explicitly specify the data types to use for A, B, C, and D.
161
+ Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
162
+ element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
163
+
164
+ # Set the data types and elements from existing tensors. Note that one can use different tensors when
165
+ # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
166
+ # have the same data type as those passed in here).
167
+ # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
168
+ Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
169
+
170
+ # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
171
+ # those passed in via the generic ``element``
172
+ Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
173
+ element=cutlass_cppgen.DataType.f32)
174
+
175
+ The order of precedence for the setting of the data type for a given operand/output is as follows:
176
+ 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
177
+ 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
178
+ 3) Otherwise, use the generic values (e.g., ``element``)
179
+
180
+ :param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
181
+ :type kind: str
182
+ :param A: tensor representing data type of operand A
183
+ :param B: tensor representing data type of operand B
184
+ :param C: tensor representing data type of operand C
185
+ :param D: tensor representing data type of operand D
186
+ :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
187
+ :param beta: scalar parameter beta from GEMM operation that scales operand C
188
+ :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
189
+ :type element: cutlass_cppgen.DataType
190
+ :param element_A: data type to be used for operand A
191
+ :type element_A: cutlass_cppgen.DataType
192
+ :param element_B: data type to be used for operand B
193
+ :type element_B: cutlass_cppgen.DataType
194
+ :param element_C: data type to be used for operand C
195
+ :type element_C: cutlass_cppgen.DataType
196
+ :param element_D: data type to be used for operand D
197
+ :type element_D: cutlass_cppgen.DataType
198
+ :param element_accumulator: data type to be used in accumulation of the product of operands A and B
199
+ :type element_accumulator: cutlass_cppgen.DataType
200
+ :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
201
+ :type cc: int
202
+ :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
203
+ :type kernel_cc: int
204
+ """
205
+ def __init__(
206
+ self, kind="fprop",
207
+ A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
208
+ element=None,
209
+ element_A=None, element_B=None, element_C=None, element_D=None,
210
+ element_accumulator=None,
211
+ cc: int = None, kernel_cc: int = None
212
+ ):
213
+ super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
214
+ # Verify the kernel cc
215
+ if self.current_cc in [90, 100, 101, 103]:
216
+ # The Conv2d kernel on Hopper (SM90) is currently unsupported
217
+ # Revert to use SM80-tagged kernels
218
+ cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
219
+ self.specified_kernel_cc = 80
220
+ self._reset_options(80)
221
+
222
+ # The arch is used in testing
223
+ self.arch = self.current_cc
224
+ self.name = "conv2d" + kind
225
+
226
+ # The convolution kind. (concept: cutlass_library.library.ConvKind)
227
+ self.conv_kind = datatypes.getattr_enum(ConvKind, kind)
228
+
229
+ # The element types (concept: cutlass library types) of A, B, C, and D
230
+ elements = []
231
+ layouts = []
232
+
233
+ # Complete the data types based on user-provided arguments
234
+ for elt, tens, name in zip([element_A, element_B, element_C, element_D],
235
+ [A, B, C, D],
236
+ ["A", "B", "C", "D"]):
237
+ if elt is not None and tens is not None:
238
+ raise Exception(f'Must not specify both element_{name} and tensor {name}')
239
+ if elt is None and tens is None and element is None:
240
+ raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
241
+
242
+ elt_to_set = None
243
+ lay_to_set = None
244
+
245
+ if tens is not None:
246
+ elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
247
+ else:
248
+ elt_to_set = elt if elt is not None else element
249
+
250
+ assert elt_to_set is not None
251
+
252
+ # Currently we only support layout TensorNHWC
253
+ lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
254
+ elements.append(datatypes.library_type(elt_to_set))
255
+ layouts.append(lay_to_set)
256
+
257
+ self._element_a, self._element_b, self._element_c, self._element_d = elements
258
+ self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
259
+
260
+ self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta
261
+
262
+ if element_accumulator is None:
263
+ self._element_accumulator = self._element_c
264
+ else:
265
+ self._element_accumulator = datatypes.library_type(element_accumulator)
266
+
267
+ # Default inputs if none is supplied in run()
268
+ self.A = A
269
+ self.B = B
270
+ self.C = C
271
+ self.D = D
272
+
273
+ self.alpha = alpha
274
+ self.beta = beta
275
+
276
+ # We only specify the stride of the swizzling functor here
277
+ # The actual swizzling functor is determined in run based on conv_kind and stride
278
+ self._swizzling_stride = 1
279
+
280
+ # Arguments that will be set to default value in _reset_operations
281
+ # The default tile_description and op_class are fetched from manifest of cutlass library
282
+ self._tile_description = None
283
+ self.op_class = None
284
+ # The default identity epilogue will be created
285
+ self.epilogue_functor = None
286
+
287
+ self._reset_operations()
288
+
289
+ # Arguments that will be determined online based on arguments of "run"
290
+ # based on stride, input/output channels, alignment, and conv_kind
291
+ self._iterator_algorithm = None
292
+ self._stride_support = None
293
+
294
+ def _reset_operations(self, reset_epilogue: bool = True):
295
+ # Set the default op class
296
+ datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
297
+ layout_comb = (self._layout_a, self._layout_b)
298
+
299
+ self.possible_op_classes = self.options.supporting_opclasses(
300
+ self._element_a, self._element_b, self._element_accumulator,
301
+ self._layout_a, self._layout_b, self._math_operation
302
+ )
303
+
304
+ if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
305
+ self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
306
+ elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
307
+ self.opclass = cutlass_cppgen.OpcodeClass.Simt
308
+ else:
309
+ if self._math_operation is not None:
310
+ math_op_str = f' and math operation {self._math_operation}'
311
+ else:
312
+ math_op_str = ''
313
+
314
+ raise Exception(f'No kernel configuration found for supported data type and layout '
315
+ f'combination {datatype_comb}x{layout_comb}{math_op_str}')
316
+
317
+ if reset_epilogue:
318
+ self._reset_epilogue_functor_activation(epilogue.identity)
319
+
320
+ self.alignment_pref_A = min(
321
+ 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
322
+ self.alignment_pref_B = min(
323
+ 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
324
+ self.alignment_pref_C = min(
325
+ 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C")))
326
+
327
+ #
328
+ # Tile description Related
329
+ #
330
+
331
+ @property
332
+ def tile_description(self) -> TileDescription:
333
+ """
334
+ Returns the tile description
335
+ """
336
+ return self._tile_description
337
+
338
+ @tile_description.setter
339
+ def tile_description(
340
+ self, td=None):
341
+ """
342
+ Set the tile description
343
+
344
+ :param td: tile description
345
+ :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
346
+ {
347
+ "threadblock_shape": [int, int, int],
348
+ "warp_count": [int, int, int],
349
+ "stages": int,
350
+ "instruction_shape": [int, int, int] (optional),
351
+ "cluster_shape": [int, int, int] (optional)
352
+ }
353
+ """
354
+ if td is None:
355
+ return
356
+ if isinstance(td, dict):
357
+ if self._tile_description is None:
358
+ op = self.possible_operations.default_operation(self._math_operation)
359
+ self._tile_description = datatypes.td_from_profiler_op(op)
360
+ if "cluster_shape" in td.keys():
361
+ if td["cluster_shape"] != [1, 1, 1]:
362
+ cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
363
+ td["cluster_shape"] = [1, 1, 1]
364
+ td = self._tile_description.clone_and_update(td)
365
+
366
+ valid, msg = self._valid_tile_description(td)
367
+ if valid:
368
+ self._tile_description = td
369
+ else:
370
+ raise Exception(msg)
371
+
372
+ def _valid_tile_description(self, td: TileDescription) -> tuple:
373
+ """
374
+ Checks whether the provided tile description is valid for the given compute capability. At present,
375
+ this checks the following:
376
+
377
+ - Does the tile description use a number of stages supported by the compute capability in question?
378
+ - Does the tile size requested fit within shared memory?
379
+ - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
380
+ more non-unit cluster dimensions for pre-SM90 architectures)?
381
+ - Is the kernel schedule being used supported on the architecture in question?
382
+
383
+ :param td: tile description to validate
384
+ :type td: cutlass_cppgen.backend.TileDescription
385
+ :return: tuple in which the first element is a bool indicating that the tile description is valid
386
+ and the second element is a string providing an optional error message.
387
+ :rtype: tuple
388
+ """
389
+ valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
390
+ if not valid:
391
+ return (valid, msg)
392
+
393
+ valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
394
+ if not valid:
395
+ return (valid, msg)
396
+
397
+ return valid, msg
398
+
399
+ def tile_descriptions(self) -> list:
400
+ """
401
+ Returns a list of valid tile descriptions for the operations
402
+
403
+ :returns: list of valid tile descriptions for the operations
404
+ :rtype: list
405
+ """
406
+ descriptions = []
407
+ description_str = []
408
+ for op in self.possible_operations.all_operations:
409
+ td = datatypes.td_from_profiler_op(op)
410
+
411
+ if self._math_operation is not None:
412
+ if td.math_instruction.math_operation != self._math_operation:
413
+ continue
414
+
415
+ if str(td) not in description_str:
416
+ description_str.append(str(td))
417
+ descriptions.append(td)
418
+ return descriptions
419
+
420
+ #
421
+ # Swizzling functor Related
422
+ #
423
+
424
+ @property
425
+ def swizzling_stride(self):
426
+ """
427
+ Returns the stride of swizzling currently being used by the Conv2d
428
+
429
+ :return: swizzing stride
430
+ """
431
+ return self._swizzling_stride
432
+
433
+ @swizzling_stride.setter
434
+ def swizzling_stride(self, stride: int):
435
+ """
436
+ Sets the swizzling functor to the type specified by `swizzling_functor`
437
+ """
438
+ if not isinstance(stride, int):
439
+ raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
440
+ self._swizzling_stride = stride
441
+
442
+ def _propose_swizzling_functor(self, stride):
443
+ """
444
+ Automatically propose the swizzling functor based on the stride
445
+ """
446
+ if self.conv_kind == ConvKind.Dgrad:
447
+ if stride[0] != 1 or stride[1] != 1:
448
+ return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
449
+
450
+ return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
451
+
452
+ #
453
+ # Iterator Algorithm Related
454
+ #
455
+
456
+ @property
457
+ def iterator_algorithm(self) -> IteratorAlgorithm:
458
+ """
459
+ Returns the iterator algorithm
460
+ """
461
+ return self._iterator_algorithm
462
+
463
+ @iterator_algorithm.setter
464
+ def iterator_algorithm(self, alg: str):
465
+ """
466
+ Sets the iterator algorithm
467
+
468
+ :param alg: The iterator algorithm
469
+ :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels"
470
+ """
471
+ iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)
472
+
473
+ # Check if the iterator algorithm is valid
474
+ if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
475
+ raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
476
+
477
+ self._iterator_algorithm = iterator_alg
478
+
479
+ def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
480
+ """
481
+ Propose a valid iterator algorithm based on problem size and alignment
482
+ """
483
+ if self.conv_kind == ConvKind.Fprop:
484
+ # Check whether the fixed channel is applicable
485
+ if problem_size.C == alignment_a:
486
+ return IteratorAlgorithm.FixedChannels
487
+ elif (problem_size.C % alignment_a == 0 and
488
+ problem_size.R <= 32 and problem_size.S <= 32):
489
+ return IteratorAlgorithm.Optimized
490
+ else:
491
+ return IteratorAlgorithm.Analytic
492
+ elif self.conv_kind == ConvKind.Dgrad:
493
+ if (problem_size.K % alignment_a == 0 and
494
+ problem_size.R <= 32 and problem_size.S <= 32 and
495
+ problem_size.C % alignment_b == 0):
496
+ return IteratorAlgorithm.Optimized
497
+ else:
498
+ return IteratorAlgorithm.Analytic
499
+ elif self.conv_kind == ConvKind.Wgrad:
500
+ if (problem_size.K % alignment_a == 0 and
501
+ problem_size.C % alignment_b == 0):
502
+ return IteratorAlgorithm.Optimized
503
+ else:
504
+ return IteratorAlgorithm.Analytic
505
+
506
+ def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
507
+ """
508
+ Validate whether the user provide iterator algorithm works for the given problem size
509
+ """
510
+ if self.conv_kind == ConvKind.Fprop:
511
+ if iterator_algorithm == IteratorAlgorithm.FixedChannels:
512
+ return problem_size.C == alignment_a
513
+ elif iterator_algorithm == IteratorAlgorithm.Optimized:
514
+ return (problem_size.C % alignment_a == 0 and
515
+ problem_size.R <= 32 and problem_size.S <= 32)
516
+ elif iterator_algorithm == IteratorAlgorithm.FewChannels:
517
+ return problem_size.C % alignment_a == 0
518
+ elif self.conv_kind == ConvKind.Dgrad:
519
+ if iterator_algorithm == IteratorAlgorithm.Optimized:
520
+ return (problem_size.K % alignment_a == 0 and
521
+ problem_size.R <= 32 and problem_size.S <= 32 and
522
+ problem_size.C % alignment_b == 0)
523
+ elif self.conv_kind == ConvKind.Wgrad:
524
+ if iterator_algorithm == IteratorAlgorithm.Optimized:
525
+ return (problem_size.K % alignment_a == 0 and
526
+ problem_size.C % alignment_b == 0)
527
+
528
+ return True
529
+
530
+ #
531
+ # Stride Support Related
532
+ #
533
+
534
+ def _propose_stride_support(self, stride):
535
+ if self.conv_kind == ConvKind.Dgrad:
536
+ if stride[0] == 1 and stride[1] == 1:
537
+ return StrideSupport.Unity
538
+
539
+ return StrideSupport.Strided
540
+
541
+ #
542
+ # Construct and Compilation
543
+ #
544
+
545
+ def construct(
546
+ self, tile_description: TileDescription = None,
547
+ alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
548
+ iterator_algorithm: IteratorAlgorithm = None,
549
+ stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
550
+ epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
551
+ """
552
+ Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
553
+ kernel specification of the ``Conv2d`` object.
554
+
555
+ :param tile_description: tile description specifying shapes and operand types to use in the kernel
556
+ :type tile_description: cutlass_cppgen.backend.TileDescription
557
+ :param alignment_A: alignment of operand A
558
+ :type alignment_A: int
559
+ :param alignment_B: alignment of operand B
560
+ :type alignment_B: int
561
+ :param alignment_C: alignment of operand C
562
+ :type alignment_C: int
563
+ :param iterator_algorithm: the iterator algorithm used
564
+ :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
565
+ :param stride_support: the stride support of dgrad
566
+ :type stride_support: cutlass_library.library.StrideSupport
567
+ :param swizzling_functor: the swizzling functor
568
+ :type swizzling_functor: cutlass_cppgen.swizzle
569
+ :param epilogue_functor: the epilogue functor
570
+
571
+ :return: operation that was constructed
572
+ :rtype: cutlass_cppgen.backend.Conv2dOperation
573
+ """
574
+ # Get alignment
575
+ alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
576
+ alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
577
+ alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)
578
+
579
+ tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
580
+ tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
581
+ tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
582
+
583
+ if tile_description is None:
584
+ if self.tile_description is not None:
585
+ tile_description = self.tile_description
586
+ else:
587
+ op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
588
+ tile_description = datatypes.td_from_profiler_op(op)
589
+ else:
590
+ valid, err_str = self._valid_tile_description(tile_description)
591
+ if not valid:
592
+ raise Exception(f"Invalid tile description. {err_str}")
593
+ self.tile_description = tile_description
594
+
595
+ if iterator_algorithm is None:
596
+ # If the iterator algorithm is already set
597
+ if self.iterator_algorithm is not None:
598
+ iterator_algorithm = self.iterator_algorithm
599
+ else:
600
+ # Otherwise, we conservatively use the analytic iterator for correctness
601
+ iterator_algorithm = IteratorAlgorithm.Analytic
602
+
603
+ if stride_support is None:
604
+ # If the stride support is already set
605
+ if self._stride_support is not None:
606
+ stride_support = self._stride_support
607
+ else:
608
+ # Otherwise, we assume strided
609
+ stride_support = StrideSupport.Strided
610
+
611
+ if swizzling_functor is None:
612
+ # If the swizzling functor is already set
613
+ swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))
614
+
615
+ if epilogue_functor is None:
616
+ if self.epilogue_functor is not None:
617
+ epilogue_functor = self.epilogue_functor
618
+ else:
619
+ epilogue_functor = self._create_epilogue_functor_activation(self._activation)
620
+
621
+ # Reset the alignment of the epilogue functor
622
+ epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)
623
+
624
+ operation = Conv2dOperation(
625
+ conv_kind=self.conv_kind,
626
+ iterator_algorithm=iterator_algorithm,
627
+ arch=self.current_cc,
628
+ tile_description=tile_description,
629
+ A=tensor_A, B=tensor_B, C=tensor_C,
630
+ stride_support=stride_support,
631
+ epilogue_functor=epilogue_functor,
632
+ swizzling_functor=swizzling_functor,
633
+ )
634
+
635
+ return operation
636
+
637
+ def compile(self, tile_description: TileDescription = None,
638
+ alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
639
+ iterator_algorithm: IteratorAlgorithm = None,
640
+ stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
641
+ epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
642
+ """
643
+ Emits and compiles the kernel currently specified. If ``tile_description`` and any
644
+ of the ``alignment`` parameters are set, the kernel will be chosen using this
645
+ tile description and alignments. Otherwise, a default tile description and alignment
646
+ will be used.
647
+
648
+ ::param tile_description: tile description specifying shapes and operand types to use in the kernel
649
+ :type tile_description: cutlass_cppgen.backend.TileDescription
650
+ :param alignment_A: alignment of operand A
651
+ :type alignment_A: int
652
+ :param alignment_B: alignment of operand B
653
+ :type alignment_B: int
654
+ :param alignment_C: alignment of operand C
655
+ :type alignment_C: int
656
+ :param iterator_algorithm: the iterator algorithm used
657
+ :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
658
+ :param stride_support: the stride support of dgrad
659
+ :type stride_support: cutlass_library.library.StrideSupport
660
+ :param swizzling_functor: the swizzling functor
661
+ :type swizzling_functor: cutlass_cppgen.swizzle
662
+ :param epilogue_functor: the epilogue functor
663
+
664
+ :return: operation that was compiled
665
+ :rtype: cutlass_cppgen.backend.Conv2dOperation
666
+ """
667
+
668
+ self.operation = self.construct(
669
+ tile_description, alignment_A, alignment_B, alignment_C,
670
+ iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
671
+
672
+ if print_module:
673
+ print(self.operation.rt_module.emit())
674
+
675
+ compiler.add_module([self.operation,])
676
+ return self.operation
677
+
678
+ #
679
+ # Run Related
680
+ #
681
+
682
+ def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
683
+ """
684
+ Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
685
+ is raised if it does not.
686
+
687
+ :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
688
+ :type tensor: numpy/cupy/torch array/tensor object
689
+ :param ref_dtype: data type for the tensor that this object was initialized to
690
+ :param name: identifier of the tensor to verify. Used in raising exceptions
691
+ :type name: str
692
+ """
693
+ dtype, _ = datatypes.get_datatype_and_layout(tensor)
694
+ if dtype != ref_type:
695
+ raise Exception(f'Tensor {name} with type and layout {dtype} '
696
+ f'does not match the expected type of {ref_type}.')
697
+
698
+ def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
699
+ if self.conv_kind == ConvKind.Fprop:
700
+ input = A
701
+ weight = B
702
+ output = C
703
+ output_tensor = "C"
704
+ elif self.conv_kind == ConvKind.Dgrad:
705
+ output = A
706
+ weight = B
707
+ input = C
708
+ output_tensor = "A"
709
+ elif self.conv_kind == ConvKind.Wgrad:
710
+ output = A
711
+ input = B
712
+ weight = C
713
+ output_tensor = "A"
714
+ else:
715
+ raise Exception(f"Convolution kind {self.conv_kind} is not supported")
716
+
717
+ N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
718
+ K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
719
+ _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")
720
+
721
+ problem_size = Conv2DProblemSize(
722
+ N_, H_, W_, C_,
723
+ K_, R_, S_, C_,
724
+ padding[0], padding[1],
725
+ stride[0], stride[1],
726
+ dilation[0], dilation[1],
727
+ ConvMode.CrossCorrelation,
728
+ 1, 1
729
+ )
730
+
731
+ if P_ != problem_size.P or Q_ != problem_size.Q:
732
+ raise Exception(
733
+ f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")
734
+
735
+ return problem_size
736
+
737
+ def run(self, A=None, B=None, C=None, D=None,
738
+ stride=(1, 1), padding=(0, 0), dilation=(1, 1),
739
+ alpha=None, beta=None,
740
+ split_k=("serial", 1), sync: bool = True,
741
+ print_module: bool = False,
742
+ stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
743
+ """
744
+ Runs the kernel currently specified. If it has not already been, the kernel is emitted and
745
+ compiled. Tensors holding operands and outputs of the kernel are sourced either from the
746
+ ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
747
+ parameters provided in the call, or from those
748
+ passed in on the construction of this object -- one of the two must be specified.
749
+
750
+ By default, this call returns only once the kernel has completed. To launch the kernel
751
+ and immediately return, set ``sync=False``. In this case, it is the responsibility of the
752
+ caller to syncrhonize the results of the kernel before attempting to access outputs
753
+ by calling ``sync()`` on the arguments returned from this call.
754
+
755
+ :param A: tensor representing data type and layout of operand A
756
+ :param B: tensor representing data type and layout of operand B
757
+ :param C: tensor representing data type and layout of operand C
758
+ :param D: tensor representing data type and layout of operand D
759
+ :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
760
+ :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
761
+ :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
762
+ :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
763
+ :param beta: scalar parameter beta from GEMM operation that scales operand C
764
+ :param split_k: a tuple (split_k_mode, split_k_slices)
765
+ :param sync: whether the call should wait for the kernel to complete before returning
766
+ :type sync: bool
767
+ :param print_module: whether to print the emitted C++ code
768
+ :type print_module: bool
769
+ :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
770
+ :type stream: :class:`cuda.cuda.CUstream`
771
+
772
+ :return: arguments passed in to the kernel
773
+ :rtype: cutlass_cppgen.backend.Conv2dArguments
774
+ """
775
+ if not stream:
776
+ stream = cuda.CUstream(0)
777
+ super().run_setup()
778
+
779
+ A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
780
+ B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
781
+ C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
782
+ D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
783
+ alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
784
+ beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
785
+
786
+ # handle the case when there is no C
787
+ if C is None:
788
+ if beta != 0:
789
+ raise Exception(f"With beta {beta} != 0, C has to be provided.")
790
+ else:
791
+ C = D
792
+
793
+ # Construct problem size based on input
794
+ # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
795
+ problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)
796
+
797
+ # Propose stride support based on input
798
+ stride_support = self._propose_stride_support(stride)
799
+
800
+ # Propose swizzling functor
801
+ swizzling_functor = self._propose_swizzling_functor(stride)
802
+
803
+ shape_a = datatypes.get_tensor_shape(A, op="CONV")
804
+ shape_b = datatypes.get_tensor_shape(B, op="CONV")
805
+ shape_c = datatypes.get_tensor_shape(C, op="CONV")
806
+
807
+ # Get the alignment
808
+ alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
809
+ alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
810
+ alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")
811
+
812
+ alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
813
+ alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
814
+ alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)
815
+
816
+ # Propose iterator algorithm based on input
817
+ if self._iterator_algorithm is None:
818
+ # Propose a default iterator algorithm based on the problem size
819
+ iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
820
+ else:
821
+ if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
822
+ iterator_algorithm = self._iterator_algorithm
823
+ else:
824
+ raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")
825
+
826
+ epilogue_args = [alpha, beta]
827
+
828
+ if hasattr(self, "_activation_args"):
829
+ if isinstance(self._activation_args, list):
830
+ epilogue_args += self._activation_args
831
+ else:
832
+ epilogue_args.append(self._activation_args)
833
+
834
+ if split_k[0] == "parallel" and split_k[1] > 1:
835
+ epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
836
+ else:
837
+ epilogue_functor = self.epilogue_functor
838
+
839
+ # The alignment is determined by the iterator function (I believe)
840
+ self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
841
+ alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
842
+ swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)
843
+
844
+ # Create reduction operation for parallel split-k
845
+ if split_k[0] == "parallel" and split_k[1] > 1:
846
+ epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
847
+ self.reduction_operation = ReductionOperation(
848
+ shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
849
+ element_accumulator=self._element_accumulator,
850
+ element_compute=self._element_accumulator,
851
+ epilogue_functor=epilogue_functor_reduction,
852
+ count=alignment_c
853
+ )
854
+ if print_module:
855
+ print(self.reduction_operation.rt_module.emit())
856
+ compiler.add_module([self.reduction_operation,])
857
+
858
+ arguments = Conv2dArguments(
859
+ operation=self.operation, problem_size=problem_size,
860
+ A=A, B=B, C=C, D=D,
861
+ output_op=self.operation.epilogue_type(*epilogue_args),
862
+ split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
863
+ split_k_slices=split_k[1],
864
+ stream=stream
865
+ )
866
+
867
+ self.operation.run(arguments)
868
+
869
+ if split_k[0] == "parallel" and split_k[1] > 1:
870
+ implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
871
+ reduction_arguments = ReductionArguments(
872
+ self.reduction_operation,
873
+ problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
874
+ partitions=split_k[1],
875
+ workspace=arguments.ptr_D,
876
+ destination=D,
877
+ source=C,
878
+ output_op=self.reduction_operation.epilogue_type(*epilogue_args),
879
+ stream=stream
880
+ )
881
+ self.reduction_operation.run(reduction_arguments)
882
+
883
+ if sync:
884
+ if split_k[0] == "parallel" and split_k[1] > 1:
885
+ reduction_arguments.sync()
886
+
887
+ # Free memory allocated by args because we are not
888
+ # calling `arguments.sync()` in this case (which will free memory)
889
+ arguments.free()
890
+ else:
891
+ arguments.sync()
892
+
893
+ return arguments
894
+
895
+ #
896
+ # Helper functions
897
+ #
898
+ @staticmethod
899
+ def output_size(input_size, weight_size, padding, stride, dilation):
900
+ problem_size = Conv2DProblemSize(
901
+ *input_size,
902
+ *weight_size,
903
+ padding[0], padding[1],
904
+ stride[0], stride[1],
905
+ dilation[0], dilation[1],
906
+ ConvMode.CrossCorrelation,
907
+ 1, 1
908
+ )
909
+ return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
910
+
911
+
912
+ #
913
+ # Easy to use interfaces for fprop, wgrad, and dgrad
914
+ #
915
+
916
+ class Conv2dFprop(Conv2d):
917
+ def __init__(
918
+ self,
919
+ input=None, weight=None, C=None, output=None, alpha=1, beta=0,
920
+ element=None,
921
+ element_input=None, element_weight=None, element_C=None, element_output=None,
922
+ element_accumulator=None,
923
+ cc: int = None, kernel_cc: int = None):
924
+ A, B, D = input, weight, output
925
+ element_A, element_B, element_D = element_input, element_weight, element_output
926
+ super().__init__(
927
+ "fprop", A, B, C, D, alpha, beta, element,
928
+ element_A, element_B, element_C, element_D,
929
+ element_accumulator, cc, kernel_cc)
930
+
931
+ def run(
932
+ self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
933
+ stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
934
+ sync: bool = True, print_module: bool = False,
935
+ stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
936
+
937
+ if not stream:
938
+ stream = cuda.CUstream(0)
939
+
940
+ A, B, D = input, weight, output
941
+ return super().run(
942
+ A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
943
+
944
+
945
+ class Conv2dDgrad(Conv2d):
946
+ def __init__(
947
+ self,
948
+ grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
949
+ element=None,
950
+ element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
951
+ element_accumulator=None,
952
+ cc: int = None, kernel_cc: int = None):
953
+ A, B, D = grad_output, weight, grad_input
954
+ element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
955
+ super().__init__(
956
+ "dgrad", A, B, C, D, alpha, beta, element,
957
+ element_A, element_B, element_C, element_D,
958
+ element_accumulator, cc, kernel_cc)
959
+
960
+ def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
961
+ stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
962
+ sync: bool = True, print_module: bool = False,
963
+ stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
964
+ #
965
+ if not stream:
966
+ stream = cuda.CUstream(0)
967
+
968
+ A, B, D = grad_output, weight, grad_input
969
+ return super().run(
970
+ A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
971
+
972
+
973
+ class Conv2dWgrad(Conv2d):
974
+ def __init__(
975
+ self,
976
+ grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
977
+ element=None,
978
+ element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
979
+ element_accumulator=None,
980
+ cc: int = None, kernel_cc: int = None):
981
+ A, B, D = grad_output, input, grad_weight
982
+ element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
983
+ super().__init__(
984
+ "wgrad", A, B, C, D, alpha, beta, element,
985
+ element_A, element_B, element_C, element_D,
986
+ element_accumulator, cc, kernel_cc)
987
+
988
+ def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
989
+ stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
990
+ sync: bool = True, print_module: bool = False,
991
+ stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
992
+ if not stream:
993
+ stream = cuda.CUstream(0)
994
+
995
+ A, B, D = grad_output, input, grad_weight
996
+ return super().run(
997
+ A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py ADDED
@@ -0,0 +1,725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Ease-of-use interface for constructing, compiling, and running GEMMs.
35
+
36
+ The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
37
+ GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
38
+ Under the hood, the interface will select sensible default parameters for the many template
39
+ parameters for CUTLASS GEMMs.
40
+
41
+ Note: optimal performance is not to be expected from this interface. To achieve optimal
42
+ performance, one should specify and tune each configuration parameter.
43
+
44
+ The simplest example of using this interface is the following:
45
+
46
+ .. highlight:: python
47
+ .. code-block:: python
48
+
49
+ # A, B, C, and D are torch/numpy/cupy tensor objects
50
+ plan = cutlass_cppgen.op.Gemm(A, B, C, D)
51
+ plan.run()
52
+
53
+
54
+ One can also use the interface by specifying data types of operands at construction
55
+ and using different tensor objects with these data types at runtime:
56
+
57
+ .. highlight:: python
58
+ .. code-block:: python
59
+
60
+ # The following is shorthand for:
61
+ # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
62
+ # element_C=torch.float32, element_D=torch.float32,
63
+ # element_accumulator=torch.float32,
64
+ # layout=cutlass_cppgen.LayoutType.RowMajor)
65
+ plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
66
+
67
+ A0 = torch.rand((128, 256), device='cuda')
68
+ B0 = torch.rand((256, 64), device='cuda')
69
+ C0 = torch.zeros((128, 64), device='cuda')
70
+ D0 = torch.zeros((128, 64), device.'cuda')
71
+ plan.run(A0, B0, C0, D0)
72
+
73
+ A = torch.rand((32, 128), device='cuda')
74
+ B = torch.rand((128, 256), device='cuda')
75
+ C = torch.zeros((32, 256), device='cuda')
76
+ D = torch.zeros((32, 256), device.'cuda')
77
+ plan.run(A1, B1, C1, D1)
78
+
79
+ The interface additionally enables one to decouple the compilation of the underlying CUTLASS
80
+ kernel from its execution:
81
+
82
+ .. highlight:: python
83
+ .. code-block:: python
84
+
85
+ plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
86
+ plan.compile()
87
+
88
+ # Do other work...
89
+
90
+ plan.run(A0, B0, C0, D0)
91
+
92
+ # Do other work...
93
+
94
+ plan.run(A1, B1, C1, D1)
95
+
96
+ Elementwise activation functions are easily fused to the GEMM via the interface:
97
+
98
+ .. highlight:: python
99
+ .. code-block:: python
100
+
101
+ plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
102
+ plan.activation = cutlass_cppgen.epilogue.relu
103
+
104
+ Operations can also be run asynchronously:
105
+
106
+ .. highlight:: python
107
+ .. code-block:: python
108
+
109
+ plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
110
+ args = plan.run()
111
+
112
+ # Do other work...
113
+
114
+ args.sync()
115
+ """
116
+ from __future__ import annotations
117
+ from typing import Optional
118
+ from math import prod
119
+
120
+ from cutlass_cppgen.utils.lazy_import import lazy_import
121
+ cuda = lazy_import("cuda.cuda")
122
+ from cutlass_library import (
123
+ DataType,
124
+ DataTypeSize,
125
+ GemmUniversalMode,
126
+ KernelScheduleSuffixes,
127
+ )
128
+
129
+ import cutlass_cppgen
130
+ from cutlass_cppgen import epilogue, swizzle
131
+ from cutlass_cppgen.backend import compiler
132
+ from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
133
+ from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
134
+ from cutlass_cppgen.backend.library import TensorDescription, TileDescription
135
+ from cutlass_cppgen.op.op import OperationBase
136
+ from cutlass_cppgen.shape import GemmCoord
137
+ from cutlass_cppgen.utils import check, datatypes
138
+
139
+
140
+ class Gemm(OperationBase):
141
+ """
142
+ Constructs a ``Gemm`` object.
143
+
144
+ The data types and layouts of operands A, B, and C, along with the data type of output D
145
+ and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
146
+ these are not to be changed after a ``Gemm`` has been constructed.
147
+
148
+ The constructor has optional parameters for flexibly setting these parameters. The following
149
+ constructors are equivalent:
150
+
151
+ .. highlight:: python
152
+ .. code-block:: python
153
+
154
+ # Use F32 for A, B, C, D, and accumulation. All operands are row major.
155
+
156
+ # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
157
+ # for operands to the same values.
158
+ Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
159
+
160
+ # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
161
+ Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
162
+ element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
163
+
164
+ # Set the data types and elements from existing tensors. Note that one can use different tensors when
165
+ # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
166
+ # have the same data type and layout as those passed in here).
167
+ # A, B, C, and D are row-major torch.Tensor objects of type torch.float32
168
+ Gemm(A=A, B=B, C=C, D=D)
169
+
170
+ # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
171
+ # the same as that for D, at present)
172
+ Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
173
+ layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
174
+
175
+ # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
176
+ # and layouts will inherit those passed in via the generic ``element`` and ``layout``
177
+ Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
178
+ element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
179
+
180
+ The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
181
+ 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
182
+ 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
183
+ 3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
184
+
185
+ :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
186
+ :type cc: int
187
+ :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
188
+ :type kernel_cc: int
189
+ :param A: tensor representing data type and layout of operand A
190
+ :param B: tensor representing data type and layout of operand B
191
+ :param C: tensor representing data type and layout of operand C
192
+ :param D: tensor representing data type and layout of operand D
193
+ :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
194
+ :param beta: scalar parameter beta from GEMM operation that scales operand C
195
+ :param element_accumulator: data type to be used in accumulation of the product of operands A and B
196
+ :type element_accumulator: cutlass_cppgen.DataType
197
+ :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
198
+ :type element: cutlass_cppgen.DataType
199
+ :param layout: generic layout type to be used for operands A, B, C, and D
200
+ :type layout: cutlass_cppgen.LayoutType
201
+ :param element_A: data type to be used for operand A
202
+ :type element_A: cutlass_cppgen.DataType
203
+ :param element_B: data type to be used for operand B
204
+ :type element_B: cutlass_cppgen.DataType
205
+ :param element_C: data type to be used for operand C
206
+ :type element_C: cutlass_cppgen.DataType
207
+ :param element_D: data type to be used for operand D
208
+ :type element_D: cutlass_cppgen.DataType
209
+ :param layout_A: layout of operand A
210
+ :type layout_A: cutlass_cppgen.LayoutType
211
+ :param layout_B: layout of operand B
212
+ :type layout_B: cutlass_cppgen.LayoutType
213
+ :param layout_C: layout of operand C
214
+ :type layout_C: cutlass_cppgen.LayoutType
215
+ :param layout_D: layout of operand D
216
+ :type layout_D: cutlass_cppgen.LayoutType
217
+ """
218
+
219
+ def __init__(
220
+ self, A=None, B=None, C=None, D=None,
221
+ alpha=1.0, beta=0.0, element_accumulator=None,
222
+ element=None, layout=None,
223
+ element_A=None, element_B=None, element_C=None, element_D=None,
224
+ layout_A=None, layout_B=None, layout_C=None,
225
+ cc: int = None, kernel_cc: int = None
226
+ ):
227
+ super().__init__(cc=cc, kernel_cc=kernel_cc)
228
+ self.name = "gemm"
229
+ self.compiled = False
230
+
231
+ elements = []
232
+ layouts = []
233
+
234
+ # Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
235
+ # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
236
+ for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
237
+ [layout_A, layout_B, layout_C, layout_C],
238
+ [A, B, C, D],
239
+ ["A", "B", "C", "D"]):
240
+ if elt is not None and tens is not None:
241
+ raise Exception(f'Must not specify both element_{name} and tensor {name}')
242
+ if lay is not None and tens is not None:
243
+ raise Exception(f'Must not specify both layout_{name} and tensor {name}')
244
+ if elt is None and tens is None and element is None:
245
+ raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
246
+ if lay is None and tens is None and layout is None:
247
+ raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
248
+
249
+ elt_to_set = None
250
+ lay_to_set = None
251
+ if tens is not None:
252
+ elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
253
+ else:
254
+ elt_to_set = elt if elt is not None else element
255
+ lay_to_set = lay if lay is not None else layout
256
+
257
+ elements.append(datatypes.library_type(elt_to_set))
258
+ layouts.append(lay_to_set)
259
+
260
+ self._element_a, self._element_b, self._element_c, self._element_d = elements
261
+ self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
262
+
263
+ if element_accumulator is None:
264
+ self._element_accumulator = self._element_c
265
+ else:
266
+ self._element_accumulator = datatypes.library_type(element_accumulator)
267
+
268
+ self.A = A
269
+ self.B = B
270
+ self.C = C
271
+ self.D = D
272
+
273
+ self.alpha = alpha
274
+ self.beta = beta
275
+
276
+ self.epilogue_functor = None
277
+ self.op_class = None
278
+ self._tile_description = None
279
+
280
+ self._reset_operations()
281
+
282
+ self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
283
+
284
+ def _reset_operations(self, reset_epilogue: bool = True):
285
+ # Set the default op class
286
+ datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
287
+ layout_comb = (self._layout_a, self._layout_b)
288
+
289
+ self.possible_op_classes = self.options.supporting_opclasses(
290
+ self._element_a, self._element_b, self._element_accumulator,
291
+ self._layout_a, self._layout_b, self._math_operation)
292
+
293
+ if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
294
+ self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
295
+ elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
296
+ self.opclass = cutlass_cppgen.OpcodeClass.Simt
297
+ else:
298
+ if self._math_operation is not None:
299
+ math_op_str = f' and math operation {self._math_operation}'
300
+ else:
301
+ math_op_str = ''
302
+
303
+ raise Exception(f'No kernel configuration found for supported data type and layout '
304
+ f'combination {datatype_comb}x{layout_comb}{math_op_str}')
305
+
306
+ if reset_epilogue:
307
+ self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
308
+
309
+ @property
310
+ def swizzling_functor(self):
311
+ """
312
+ Returns the type of the swizzling functor currently being used by the GEMM
313
+
314
+ :return: swizzing functor type
315
+ """
316
+ return self._swizzling_functor
317
+
318
+ @swizzling_functor.setter
319
+ def swizzling_functor(self, swizzling_functor):
320
+ """
321
+ Sets the swizzling functor to the type specified by `swizzling_functor`
322
+ """
323
+ if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
324
+ if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
325
+ raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
326
+
327
+ if self.current_cc in [90, 100, 101, 103]:
328
+ raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')
329
+ self._swizzling_functor = swizzling_functor
330
+
331
+ #
332
+ # Tile description Related
333
+ #
334
+
335
+ @property
336
+ def tile_description(self) -> TileDescription:
337
+ """
338
+ Returns the tile description
339
+ """
340
+ return self._tile_description
341
+
342
+ @tile_description.setter
343
+ def tile_description(
344
+ self, td=None):
345
+ """
346
+ Set the tile description
347
+
348
+ :param td: tile description
349
+ :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
350
+ {
351
+ "threadblock_shape": [int, int, int],
352
+ "warp_count": [int, int, int],
353
+ "stages": int,
354
+ "instruction_shape": [int, int, int] (optional),
355
+ "cluster_shape": [int, int, int] (optional)
356
+ }
357
+ """
358
+ if td is None:
359
+ return
360
+ if isinstance(td, dict):
361
+ if self._tile_description is None:
362
+ op = self.possible_operations.default_operation(self._math_operation)
363
+ self._tile_description = datatypes.td_from_profiler_op(op)
364
+ td = self._tile_description.clone_and_update(td)
365
+
366
+ valid, msg = self._valid_tile_description(td)
367
+ if valid:
368
+ self._tile_description = td
369
+ else:
370
+ raise Exception(msg)
371
+
372
+ def _valid_tile_description(self, td: TileDescription) -> tuple:
373
+ """
374
+ Checks whether the provided tile description is valid for the given compute capability. At present,
375
+ this checks the following:
376
+
377
+ - Does the tile description use a number of stages supported by the compute capability in question?
378
+ - Does the tile size requested fit within shared memory?
379
+ - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
380
+ more non-unit cluster dimensions for pre-SM90 architectures)?
381
+ - Is the kernel schedule being used supported on the architecture in question?
382
+
383
+ :param td: tile description to validate
384
+ :type td: cutlass_cppgen.backend.TileDescription
385
+ :return: tuple in which the first element is a bool indicating that the tile description is valid
386
+ and the second element is a string providing an optional error message.
387
+ :rtype: tuple
388
+ """
389
+ valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
390
+ if not valid:
391
+ return (valid, msg)
392
+
393
+ valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
394
+ if not valid:
395
+ return (valid, msg)
396
+
397
+ valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
398
+
399
+ if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0:
400
+ valid = False
401
+ msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"
402
+
403
+ return valid, msg
404
+
405
+ def tile_descriptions(self) -> list:
406
+ """
407
+ Returns a list of valid tile descriptions for the operations
408
+
409
+ :returns: list of valid tile descriptions for the operations
410
+ :rtype: list
411
+ """
412
+ tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
413
+ if self._math_operation is not None:
414
+ tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation]
415
+ return tds
416
+
417
+ def construct(
418
+ self, tile_description: TileDescription = None,
419
+ alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
420
+ """
421
+ Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
422
+ kernel specification of the ``Gemm`` object.
423
+
424
+ :param tile_description: tile description specifying shapes and operand types to use in the kernel
425
+ :type tile_description: cutlass_cppgen.backend.TileDescription
426
+ :param alignment_A: alignment of operand A
427
+ :type alignment_A: int
428
+ :param alignment_B: alignment of operand B
429
+ :type alignment_B: int
430
+ :param alignment_C: alignment of operand C
431
+ :type alignment_C: int
432
+
433
+ :return: operation that was constructed
434
+ :rtype: cutlass_cppgen.backend.GemmOperationUniversal
435
+ """
436
+ alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
437
+ alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
438
+ alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
439
+ alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
440
+
441
+ tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
442
+ tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
443
+
444
+ if alignment_C is None:
445
+ alignment_C = max(self.possible_operations.alignments("C"))
446
+ if self._element_c != DataType.void:
447
+ alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
448
+
449
+ if tile_description is None:
450
+ if self._tile_description is None:
451
+ op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
452
+ tile_description = datatypes.td_from_profiler_op(op)
453
+
454
+ # The selected op may have lower alignment than that determined above, so we must
455
+ # reset alignment here.
456
+ alignment_C = op.C.alignment
457
+ else:
458
+ tile_description = self._tile_description
459
+ else:
460
+ valid, err_str = self._valid_tile_description(tile_description)
461
+ if not valid:
462
+ raise Exception(f"Invalid tile description. {err_str}")
463
+ self._tile_description = tile_description
464
+
465
+ tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
466
+ self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
467
+
468
+ operation = GemmOperationUniversal(
469
+ arch=self.current_cc,
470
+ tile_description=tile_description,
471
+ A=tensor_A, B=tensor_B, C=tensor_C,
472
+ epilogue_functor=self.epilogue_functor,
473
+ swizzling_functor=self._swizzling_functor,
474
+ )
475
+
476
+ return operation
477
+
478
+ def compile(self, tile_description: TileDescription = None,
479
+ alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
480
+ print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
481
+ """
482
+ Emits and compiles the kernel currently specified. If ``tile_description`` and any
483
+ of the ``alignment`` parameters are set, the kernel will be chosen using this
484
+ tile description and alignments. Otherwise, a default tile description and alignment
485
+ will be used.
486
+
487
+ :param tile_description: tile description specifying shapes and operand types to use in the kernel
488
+ :type tile_description: cutlass_cppgen.backend.TileDescription
489
+ :param alignment_A: alignment of operand A
490
+ :type alignment_A: int
491
+ :param alignment_B: alignment of operand B
492
+ :type alignment_B: int
493
+ :param alignment_C: alignment of operand C
494
+ :type alignment_C: int
495
+ :param print_module: whether to print the emitted C++ code
496
+ :type print_module: bool
497
+
498
+ :return: operation that was compiled
499
+ :rtype: cutlass_cppgen.backend.GemmOperationUniversal
500
+ """
501
+ self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
502
+
503
+ if print_module:
504
+ print(self.operation.rt_module.emit())
505
+
506
+ compiler.add_module([self.operation,])
507
+ return self.operation
508
+
509
+ def _verify_rank(self, tensor):
510
+ """
511
+ Verifies that ``tensor`` has rank greater than 1
512
+
513
+ :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
514
+ :type tensor: numpy/cupy/torch array/tensor object
515
+ """
516
+ if len(tensor.shape) < 2:
517
+ raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
518
+
519
+ def _get_batch_count(self, A, B, C, D) -> int:
520
+ """
521
+ Returns the batch count specified by the tensors A, B, C, and D and verifies that these
522
+ tensors match in batch size. Presence of a batch dimension is detected by one of the
523
+ tensors being rank 3. If a batch dimension is present, it must be present in one of
524
+ operands A, B, or C (but need not be in all), and must be present in D.
525
+
526
+ :param A: tensor A
527
+ :type A: numpy/cupy/torch array/tensor object
528
+ :param B: tensor B
529
+ :type B: numpy/cupy/torch array/tensor object
530
+ :param C: tensor C
531
+ :type C: numpy/cupy/torch array/tensor object
532
+ :param D: tensor D
533
+ :type D: numpy/cupy/torch array/tensor object
534
+
535
+ :return: tuple of batch count dimensions
536
+ :rtype: tuple
537
+ """
538
+ A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
539
+ B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1
540
+
541
+ if 1 not in [A_batch, B_batch]:
542
+ if A_batch != B_batch:
543
+ raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
544
+ return max(A_batch, B_batch)
545
+
546
+ def _get_batch_stride(self, tensor) -> int:
547
+ """
548
+ Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
549
+
550
+ :param tensor: tensor object to process
551
+ :type tensor: numpy/cupy/torch array/tensor object
552
+
553
+ :return: stride between each matrix in the batch
554
+ :rtype: int
555
+ """
556
+ if tensor is not None and len(tensor.shape) > 2:
557
+ return tensor.shape[-2] * tensor.shape[-1]
558
+ else:
559
+ return 0
560
+
561
+ def _get_problem_args(self, A, B, C, D) -> tuple:
562
+ """
563
+ Returns the problem size and GEMM universal mode to use for the
564
+ given operands.
565
+
566
+ :param A: tensor A
567
+ :type A: numpy/cupy/torch array/tensor object
568
+ :param B: tensor B
569
+ :type B: numpy/cupy/torch array/tensor object
570
+ :param C: tensor C
571
+ :type C: numpy/cupy/torch array/tensor object
572
+ :param D: tensor D
573
+ :type D: numpy/cupy/torch array/tensor object
574
+
575
+ :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
576
+ :rtype: tuple
577
+ """
578
+ M, K = A.shape[-2:]
579
+ N = B.shape[-1]
580
+ mode = GemmUniversalMode.Gemm
581
+
582
+ batch_count = self._get_batch_count(A, B, C, D)
583
+ returned_batch_count = batch_count
584
+
585
+ # If we are running a batched GEMM in which there is a nonzero batch stride
586
+ # only for A, then we can fold the batched dimension of A into the M dimension
587
+ # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
588
+ # and C are row major. A similar operation can be performed if only B has a nonzero
589
+ # batch dimension
590
+ if batch_count > 1:
591
+ A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
592
+ B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
593
+ C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor
594
+
595
+ # Consider a Tensor to be batched if its rank is > 2 and
596
+ # the product of the modes beyond rank 2 equals our pre-determined batch size.
597
+ batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)
598
+
599
+ if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
600
+ M *= batch_count
601
+ returned_batch_count = 1
602
+ elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
603
+ N *= batch_count
604
+ returned_batch_count = 1
605
+ else:
606
+ mode = GemmUniversalMode.Batched
607
+
608
+ return GemmCoord(M, N, K), mode, returned_batch_count
609
+
610
+ def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
611
+ """
612
+ Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
613
+ is raised if it does not.
614
+
615
+ :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
616
+ :type tensor: numpy/cupy/torch array/tensor object
617
+ :param ref_dtype: data type for the tensor that this object was initialized to
618
+ :param ref_layout: layout for the tensor that this object was initialized to
619
+ :param name: identifier of the tensor to verify. Used in raising exceptions
620
+ :type name: str
621
+ """
622
+ dtype, layout = datatypes.get_datatype_and_layout(tensor)
623
+ if dtype != ref_type or layout != ref_layout:
624
+ try:
625
+ # Attempt to transpose the tensor to fit the desired layout
626
+ tensor = tensor.transpose(-1, -2)
627
+ except:
628
+ raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
629
+ f'does not match the expected type and '
630
+ f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
631
+
632
+ def run(self, A=None, B=None, C=None, D=None,
633
+ alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
634
+ stream: Optional[cuda.CUstream] = None) -> GemmArguments:
635
+ """
636
+ Runs the kernel currently specified. If it has not already been, the kernel is emitted and
637
+ compiled. Tensors holding operands and outputs of the kernel are sourced either from the
638
+ ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
639
+ parameters provided in this call, or from those
640
+ passed in on the construction of this object -- one of the two must be specified.
641
+
642
+ By default, this call returns only once the kernel has completed. To launch the kernel
643
+ and immediately return, set ``sync=False``. In this case, it is the responsibility of the
644
+ caller to syncrhonize the results of the kernel before attempting to access outputs
645
+ by calling ``sync()`` on the arguments returned from this call.
646
+
647
+ :param A: tensor representing data type and layout of operand A
648
+ :param B: tensor representing data type and layout of operand B
649
+ :param C: tensor representing data type and layout of operand C
650
+ :param D: tensor representing data type and layout of operand D
651
+ :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
652
+ :param beta: scalar parameter beta from GEMM operation that scales operand C
653
+ :param sync: whether the call should wait for the kernel to complete before returning
654
+ :type sync: bool
655
+ :param print_module: whether to print the emitted C++ code
656
+ :type print_module: bool
657
+ :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
658
+ :type stream: :class:`cuda.cuda.CUstream`
659
+
660
+ :return: arguments passed in to the kernel
661
+ :rtype: cutlass_cppgen.backend.GemmArguments
662
+ """
663
+ if not stream:
664
+ stream = cuda.CUstream(0)
665
+ super().run_setup()
666
+ A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
667
+ B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
668
+ C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
669
+ D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
670
+ alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
671
+ beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
672
+
673
+ is_void_c = self._element_c == DataType.void
674
+
675
+ self._verify_rank(A)
676
+ self._verify_rank(B)
677
+ if not is_void_c:
678
+ self._verify_rank(C)
679
+ self._verify_rank(D)
680
+
681
+ alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
682
+ alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")
683
+
684
+ # Set C alignment based on D.shape so as to correctly get an alignment with void-C
685
+ # kernels, for which `C` is None.
686
+ alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
687
+ self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
688
+ alignment_C=alignment_c, print_module=print_module)
689
+
690
+ problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
691
+
692
+ if mode == GemmUniversalMode.Gemm or batch_count == 1:
693
+ kwargs = {'split_k_slices': 1}
694
+ else:
695
+ kwargs = {
696
+ 'batch': batch_count,
697
+ 'batch_strides': {
698
+ 'A': self._get_batch_stride(A),
699
+ 'B': self._get_batch_stride(B),
700
+ 'C': self._get_batch_stride(C),
701
+ 'D': self._get_batch_stride(D)
702
+ }
703
+ }
704
+
705
+ kwargs['stream'] = stream
706
+
707
+ if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
708
+ output_op = self.operation.epilogue_type(visitor_args)
709
+ else:
710
+ output_op = self.operation.epilogue_type(alpha, beta)
711
+
712
+ arguments = GemmArguments(
713
+ operation=self.operation, problem_size=problem_size,
714
+ A=A, B=B, C=C, D=D,
715
+ output_op=output_op,
716
+ gemm_mode=mode,
717
+ **kwargs
718
+ )
719
+
720
+ self.operation.run(arguments)
721
+
722
+ if sync:
723
+ arguments.sync()
724
+
725
+ return arguments
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Ease-of-use interface for constructing, compiling, and running GEMMs.
35
+
36
+ The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
37
+ grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
38
+ Under the hood, the interface will select sensible default parameters for the many template
39
+ parameters for CUTLASS grouped GEMMs.
40
+
41
+ Note: optimal performance is not to be expected from this interface. To achieve optimal
42
+ performance, one should specify and tune each configuration parameter.
43
+
44
+ The simplest example of using this interface is the following:
45
+
46
+ .. highlight:: python
47
+ .. code-block:: python
48
+
49
+ # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
50
+ plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
51
+ plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
52
+ """
53
+ from __future__ import annotations
54
+ from typing import Optional
55
+ from cutlass_library import DataTypeSize
56
+
57
+ from cutlass_cppgen.utils.lazy_import import lazy_import
58
+ cuda = lazy_import("cuda.cuda")
59
+ from cutlass_cppgen.backend.gemm_operation import (
60
+ GemmGroupedArguments,
61
+ GemmOperationGrouped,
62
+ )
63
+ from cutlass_cppgen.backend.library import (
64
+ SchedulerMode,
65
+ TensorDescription,
66
+ TileDescription,
67
+ )
68
+ from cutlass_cppgen.op.gemm import Gemm
69
+ from cutlass_cppgen.shape import GemmCoord
70
+ from cutlass_cppgen.utils import check, datatypes
71
+
72
+
73
+ class GroupedGemm(Gemm):
74
+ """
75
+ Constructs a ``GroupedGemm`` object.
76
+
77
+ The data types and layouts of operands A, B, and C, along with the data type of output D
78
+ and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
79
+ these are not to be changed after a ``GroupedGemm`` has been constructed.
80
+
81
+ The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
82
+ for ``Gemm`` for examples of these.
83
+
84
+ :param cc: compute capability of device to generate kernels for
85
+ :type cc: int
86
+ :param A: tensor representing data type and layout of operands A
87
+ :param B: tensor representing data type and layout of operands B
88
+ :param C: tensor representing data type and layout of operands C
89
+ :param D: tensor representing data type and layout of operands D
90
+ :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
91
+ :param beta: scalar parameter beta from GEMM operation that scales operand C
92
+ :param element_accumulator: data type to be used in accumulation of the product of operands A and B
93
+ :type element_accumulator: cutlass_cppgen.DataType
94
+ :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
95
+ :type element: cutlass_cppgen.DataType
96
+ :param layout: generic layout type to be used for operands A, B, C, and D
97
+ :type layout: cutlass_cppgen.LayoutType
98
+ :param element_A: data type to be used for operand A
99
+ :type element_A: cutlass_cppgen.DataType
100
+ :param element_B: data type to be used for operand B
101
+ :type element_B: cutlass_cppgen.DataType
102
+ :param element_C: data type to be used for operand C
103
+ :type element_C: cutlass_cppgen.DataType
104
+ :param element_D: data type to be used for operand D
105
+ :type element_D: cutlass_cppgen.DataType
106
+ :type layout_A: layout of operand A
107
+ :param layout_A: cutlass_cppgen.LayoutType
108
+ :type layout_B: layout of operand B
109
+ :param layout_B: cutlass_cppgen.LayoutType
110
+ :type layout_C: layout of operand C
111
+ :param layout_C: cutlass_cppgen.LayoutType
112
+ :type layout_D: layout of operand D
113
+ :param layout_D: cutlass_cppgen.LayoutType
114
+ """
115
+
116
+ def __init__(
117
+ self, A=None, B=None, C=None, D=None,
118
+ alpha=1.0, beta=0.0, element_accumulator=None,
119
+ element=None, layout=None,
120
+ element_A=None, element_B=None, element_C=None, element_D=None,
121
+ layout_A=None, layout_B=None, layout_C=None,
122
+ cc: int = None,
123
+ ):
124
+ super().__init__(
125
+ A=A, B=B, C=C, D=D,
126
+ alpha=alpha, beta=beta,
127
+ element_accumulator=element_accumulator,
128
+ element=element, layout=layout,
129
+ element_A=element_A, element_B=element_B,
130
+ element_C=element_C, element_D=element_D,
131
+ layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
132
+ cc=cc
133
+ )
134
+
135
+ # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
136
+ if self.current_cc in [90, 100, 101, 103]:
137
+ self._reset_options(80)
138
+ self._reset_operations(reset_epilogue=False)
139
+
140
+ self.name = "grouped_gemm"
141
+
142
+ @Gemm.swizzling_functor.setter
143
+ def swizzling_functor(self, swizzling_functor):
144
+ """
145
+ Sets the swizzling functor to the type specified by `swizzling_functor`
146
+ """
147
+ raise Exception('Grouped GEMM does not currently support different swizzling functors')
148
+
149
+ def construct(self, tile_description: TileDescription = None,
150
+ alignment_A: int = None,
151
+ alignment_B: int = None,
152
+ alignment_C: int = None) -> GemmOperationGrouped:
153
+ """
154
+ Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
155
+ kernel specification of the ``Gemm`` object.
156
+
157
+ :param tile_description: tile description specifying shapes and operand types to use in the kernel
158
+ :type tile_description: cutlass_cppgen.backend.TileDescription
159
+ :param alignment_A: alignment of operand A
160
+ :type alignment_A: int
161
+ :param alignment_B: alignment of operand B
162
+ :type alignment_B: int
163
+ :param alignment_C: alignment of operand C
164
+ :type alignment_C: int
165
+
166
+ :return: operation that was constructed
167
+ :rtype: cutlass_cppgen.backend.GemmOperationGrouped
168
+ """
169
+ alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
170
+ alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
171
+ alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))
172
+
173
+ self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
174
+
175
+ tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
176
+ tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
177
+ tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
178
+
179
+ if tile_description is None:
180
+ op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
181
+ tile_description = datatypes.td_from_profiler_op(op)
182
+ else:
183
+ valid, err_str = self._valid_tile_description(tile_description)
184
+ if not valid:
185
+ raise Exception(f"Invalid tile description. {err_str}")
186
+ self.tile_description = tile_description
187
+
188
+ operation = GemmOperationGrouped(
189
+ arch=self.current_cc,
190
+ tile_description=tile_description,
191
+ A=tensor_A, B=tensor_B, C=tensor_C,
192
+ epilogue_functor=self.epilogue_functor,
193
+ swizzling_functor=self._swizzling_functor,
194
+ precompute_mode=SchedulerMode.Device)
195
+
196
+ return operation
197
+
198
+ def run(self, A, B, C, D,
199
+ alpha=None, beta=None, sync: bool = True,
200
+ print_module: bool = False,
201
+ stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
202
+ """
203
+ Runs the kernel currently specified.
204
+
205
+ By default, this call returns only once the kernel has completed. To launch the kernel
206
+ and immediately return, set ``sync=False``. In this case, it is the responsibility of the
207
+ caller to syncrhonize the results of the kernel before attempting to access outputs
208
+ by calling ``sync()`` on the arguments returned from this call.
209
+
210
+ :param A: list of tensors representing data type and layout of operand A
211
+ :type A: list
212
+ :param B: list of tensors representing data type and layout of operand B
213
+ :type B: list
214
+ :param C: list of tensors representing data type and layout of operand C
215
+ :type C: list
216
+ :param D: list of tensors representing data type and layout of operand D
217
+ :type D: list
218
+ :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
219
+ :param beta: scalar parameter beta from GEMM operation that scales operand C
220
+ :param sync: whether the call should wait for the kernel to complete before returning
221
+ :type sync: bool
222
+ :param print_module: whether to print the emitted C++ code
223
+ :type print_module: bool
224
+ :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
225
+ :type stream: :class:`cuda.cuda.CUstream`
226
+
227
+ :return: arguments passed in to the kernel
228
+ :rtype: cutlass_cppgen.backend.GemmGroupedArguments
229
+ """
230
+ if not stream:
231
+ stream = cuda.CUstream(0)
232
+
233
+ super().run_setup()
234
+
235
+ if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
236
+ raise Exception("Lengths of A, B, C, and D lists must be equal")
237
+
238
+ problem_sizes = []
239
+ As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
240
+ for i in range(len(A)):
241
+ As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
242
+ Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
243
+ Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
244
+ Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
245
+ problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
246
+
247
+ alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
248
+ beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
249
+
250
+ alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
251
+ alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
252
+ alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
253
+ self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
254
+ alignment_C=alignment_c, print_module=print_module)
255
+
256
+ arguments = GemmGroupedArguments(
257
+ operation=self.operation,
258
+ problem_sizes=problem_sizes,
259
+ A=As, B=Bs, C=Cs, D=Ds,
260
+ output_op=self.operation.epilogue_type(alpha, beta),
261
+ stream=stream
262
+ )
263
+
264
+ self.operation.run(arguments)
265
+
266
+ if sync:
267
+ arguments.sync()
268
+
269
+ return arguments
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
35
+ """
36
+
37
+ from bisect import bisect_left
38
+
39
+ from cutlass_library import (
40
+ DataType,
41
+ DataTypeSize,
42
+ MathOperation,
43
+ OperationKind,
44
+ SharedMemPerCC
45
+ )
46
+
47
+ import cutlass_cppgen
48
+ from cutlass_cppgen import get_option_registry
49
+ from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
50
+ from cutlass_cppgen.backend.evt.passes.util import cc_map
51
+ from cutlass_cppgen.backend.utils.device import device_cc
52
+ from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
53
+ from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
54
+ from cutlass_cppgen.swizzle import get_swizzling_functors
55
+ from cutlass_cppgen.utils import datatypes, check
56
+
57
+
58
+ class OperationBase:
59
+ """
60
+ Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
61
+ """
62
+
63
+ def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
64
+ """
65
+ :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
66
+ :type cc: int
67
+ :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
68
+ :type kernel_cc: int
69
+ :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
70
+ :type operation_kind: cutlass_library.OperationKind
71
+ """
72
+ self.operation_kind = operation_kind
73
+ self.cc = cc if cc is not None else device_cc()
74
+ self.specified_kernel_cc = kernel_cc is not None
75
+ self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
76
+ self.tile_description = None
77
+ self._math_operation = None
78
+
79
+ self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)
80
+
81
+ if self.options is None:
82
+ raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
83
+
84
+ # Default activation function: identity
85
+ self._activation = identity
86
+
87
+ def _find_closest_cc(self, cc: int) -> int:
88
+ """
89
+ Returns the closest CC in _generator_ccs less than or equal to `cc`
90
+
91
+ :param cc: compute capability to query
92
+ :type cc: int
93
+
94
+ :returns: closest CC in _generator_ccs less than or equal to `cc`
95
+ :rtype: int
96
+ """
97
+ if cc in _generator_ccs:
98
+ return cc
99
+
100
+ # Find closest CC lower than this CC
101
+ idx = bisect_left(_generator_ccs, cc)
102
+ if idx == 0:
103
+ raise Exception(f'No valid CC to fall back to for {cc}')
104
+ return _generator_ccs[idx-1]
105
+
106
+ def activations(self) -> list:
107
+ """
108
+ Returns possible activation functions that can be used
109
+
110
+ :return: list of activation functions that can be used
111
+ :rtype: list
112
+ """
113
+ return get_activations()
114
+
115
+ def swizzling_functors(self) -> list:
116
+ """
117
+ Returns possible swizzling functions that can be used
118
+
119
+ :return: list of swizzling functions that can be used
120
+ :rtype: list
121
+ """
122
+ return get_swizzling_functors()
123
+
124
+ def _reset_options(self, cc: int):
125
+ """
126
+ Resets the kernel options based on cc
127
+
128
+ :param cc: compute capability to reset to
129
+ :type cc: int
130
+ """
131
+ if cc != self.current_cc:
132
+ if cc not in _generator_ccs:
133
+ raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
134
+ self.current_cc = cc
135
+ self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)
136
+
137
+ def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
138
+ """
139
+ Verifies the following properties:
140
+ 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
141
+ 2) If ``scalar`` is not ``None``, its datatype must match matches the current version
142
+ set by the plan (i.e., those in ``ref_dtype``)
143
+
144
+ If either of these properties does not hold, an exception is raised. If these properties hold and
145
+ ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
146
+
147
+ :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
148
+ :type scalar: numpy/cupy/torch scalar
149
+ :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
150
+ :type ref_scalar: numpy/cupy/torch scalar
151
+ :param ref_dtype: data type for the scalar that this object was initialized to
152
+ :param name: identifier of the scalar to verify. Used in raising exceptions
153
+ :type name: str
154
+
155
+ :return: valid scalar to use
156
+ :rtype: numpy/cupy/torch scalar
157
+ """
158
+ if scalar is None:
159
+ if ref_scalar is None:
160
+ raise Exception(f"Scalar {name} must be set.")
161
+ return ref_scalar
162
+ if hasattr(scalar, "dtype"):
163
+ dtype = datatypes.library_type(scalar.dtype)
164
+ if dtype != ref_dtype:
165
+ raise Exception(
166
+ f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
167
+ )
168
+ return scalar
169
+
170
+ def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
171
+ """
172
+ Verifies the following properties:
173
+ If ref_dtype is not void:
174
+ 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
175
+ 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
176
+ set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
177
+ If ref_dtype is void:
178
+ Neither ``tensor`` nor ``ref_tensor`` are set
179
+
180
+ If either of these properties does not hold, an exception is raised. If these properties hold and
181
+ ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
182
+
183
+ :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
184
+ :type tensor: numpy/cupy/torch array/tensor object
185
+ :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
186
+ :type ref_tensor: numpy/cupy/torch array/tensor object
187
+ :param ref_dtype: data type for the tensor that this object was initialized to
188
+ :param ref_layout: layout for the tensor that this object was initialized to
189
+ :param name: identifier of the tensor to verify. Used in raising exceptions
190
+ :type name: str
191
+
192
+ :return: valid tensor object to use
193
+ :rtype: numpy/cupy/torch array/tensor object
194
+ """
195
+ if ref_dtype == DataType.void:
196
+ if tensor is not None or ref_tensor is not None:
197
+ raise Exception("Operands with element DataType.void must not be provided a tensor")
198
+ return None
199
+
200
+ if tensor is None:
201
+ if ref_tensor is None:
202
+ raise Exception(f"Tensor {name} must be set.")
203
+ return ref_tensor
204
+
205
+ self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
206
+ return tensor
207
+
208
+ @property
209
+ def opclass(self) -> cutlass_cppgen.OpcodeClass:
210
+ """
211
+ Returns the opcode class currently in use
212
+
213
+ :return: opcode class currently in use
214
+ :rtype: cutlass_cppgen.OpcodeClass
215
+ """
216
+ return self.op_class
217
+
218
+ @opclass.setter
219
+ def opclass(self, oc: cutlass_cppgen.OpcodeClass):
220
+ if isinstance(oc, str):
221
+ oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
222
+ if oc in self.possible_op_classes:
223
+ self.op_class = oc
224
+ else:
225
+ raise Exception(
226
+ f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
227
+ f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
228
+ f'layout combination ({self._layout_a}, {self._layout_b}).')
229
+
230
+ # Changing the op class also changes the possible operations available. Reset these.
231
+ self.possible_operations = self.options.operations(
232
+ self.op_class, self._element_a, self._element_b,
233
+ self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
234
+
235
+ # Changing the op class changes the elements per access in the epilogue. Reset this.
236
+ if self.epilogue_functor is not None:
237
+ self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
238
+
239
+ @property
240
+ def math_operation(self) -> cutlass_cppgen.MathOperation:
241
+ """
242
+ Returns the math operation currently in use
243
+
244
+ :return: math operation currently in use
245
+ :rtype: cutlass_cppgen.MathOperation
246
+ """
247
+ return self._math_operation
248
+
249
+ @math_operation.setter
250
+ def math_operation(self, mo: cutlass_cppgen.MathOperation):
251
+ if isinstance(mo, str):
252
+ mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
253
+
254
+ if not self.specified_kernel_cc:
255
+ if self.current_cc in [90, 100, 101, 103]:
256
+ # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
257
+ # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
258
+ cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
259
+ self._reset_options(80)
260
+ self._reset_operations(reset_epilogue=False)
261
+ elif self.current_cc in [90, 100, 101, 103]:
262
+ raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
263
+ "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
264
+ "parameter when constructing the plan.")
265
+
266
+ self._math_operation = mo
267
+ self._reset_operations()
268
+
269
+ def _elements_per_access(self):
270
+ if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
271
+ return 1
272
+ elif self._element_c != DataType.void:
273
+ return 128 // DataTypeSize[self._element_c]
274
+ else:
275
+ return 128 // max(self.possible_operations.alignments("C"))
276
+
277
+ def _create_epilogue_functor_activation(self, activation):
278
+ """
279
+ Returns the epilogue functor with given activation function
280
+ """
281
+ if self.epilogue_functor is None:
282
+ elements_per_access = self._elements_per_access()
283
+ else:
284
+ elements_per_access = self.epilogue_functor.epilogue_vector_length
285
+
286
+ if not self.specified_kernel_cc:
287
+ if self.current_cc in [90, 100, 101, 103] and activation != identity:
288
+ # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
289
+ # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
290
+ cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
291
+ if self._element_c != self._element_d:
292
+ raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
293
+ self._reset_options(80)
294
+ self._reset_operations(reset_epilogue=False)
295
+ elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None):
296
+ # SM80 fallback kernels are currently used. Since an identity activation is requested,
297
+ # we can switch back to using SM90 kernels.
298
+ self._reset_options(self.cc)
299
+ self._reset_operations(reset_epilogue=False)
300
+ else:
301
+ if self.current_cc in [90, 100, 101, 103] and activation != identity:
302
+ raise Exception("Epilogues with elementwise fusion are not currently supported "
303
+ "in the Python interface for 3.x kernels. To use 2.x kernels "
304
+ "with fused elementwise epilogues, do not set the `kernel_cc` "
305
+ "parameter when constructing the plan.")
306
+
307
+ return get_activation_epilogue(
308
+ activation,
309
+ self._element_d,
310
+ elements_per_access,
311
+ self._element_accumulator,
312
+ self._element_accumulator,
313
+ )
314
+
315
+ def _reset_epilogue_functor_activation(self, activation):
316
+ """
317
+ Set the epilogue functor based on the provided activation function
318
+ """
319
+ self.epilogue_functor = self._create_epilogue_functor_activation(activation)
320
+
321
+ def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
322
+ """
323
+ Reset the alignment of the current epilogue functor based on alignment C
324
+ """
325
+ if isinstance(epilogue_functor, EpilogueFunctorVisitor):
326
+ return epilogue_functor
327
+
328
+ if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
329
+ # Identity epilogue does not have 'activation_functor'
330
+ activation = identity
331
+ else:
332
+ activation = epilogue_functor.activation_functor
333
+
334
+ epilogue_functor = get_activation_epilogue(
335
+ activation,
336
+ self._element_d,
337
+ alignment,
338
+ self._element_accumulator,
339
+ self._element_accumulator,
340
+ )
341
+ return epilogue_functor
342
+
343
+ @property
344
+ def activation(self):
345
+ """
346
+ Returns the type of the current activation function used
347
+ """
348
+ if hasattr(self.epilogue_functor, "activation_functor"):
349
+ return self.epilogue_functor.activation_functor
350
+ else:
351
+ return identity
352
+
353
+ @activation.setter
354
+ def activation(self, act):
355
+ """
356
+ Sets the type of the activation function to use
357
+ Activation can come with a set of arguments
358
+
359
+ :param act: type of activation function to use
360
+ :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
361
+
362
+ """
363
+ if isinstance(act, tuple):
364
+ if isinstance(act[0], str):
365
+ act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
366
+ else:
367
+ act_fn = act[0]
368
+ self._reset_epilogue_functor_activation(act_fn)
369
+ self._activation_args = act[1]
370
+ self._activation = act[0]
371
+ else:
372
+ if isinstance(act, str):
373
+ act = getattr(cutlass_cppgen.backend.epilogue, act)
374
+ self._reset_epilogue_functor_activation(act)
375
+ self._activation = act
376
+
377
+ @property
378
+ def epilogue_visitor(self):
379
+ """
380
+ Return the epilogue functor
381
+ """
382
+ return self.epilogue_functor
383
+
384
+ @epilogue_visitor.setter
385
+ def epilogue_visitor(self, visitor):
386
+ """
387
+ Create the epilogue visitor
388
+ """
389
+ self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)
390
+
391
+ # The epilogue_functor may consume too much shared memory
392
+ # Reset the possible operations
393
+ if self.cc not in [90, 100, 101, 103]:
394
+ # The shared memory is only a concern for sm90+ epilogue
395
+ # In sm80, the epilogue and mainloop share the shared memory
396
+ return
397
+
398
+ datatype_comb = self.possible_operations.datatype_comb
399
+ layout_comb = self.possible_operations.layout_comb
400
+ new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
401
+ for operation in self.possible_operations.all_operations:
402
+ td = datatypes.td_from_profiler_op(operation)
403
+ # Filter invalid epilogue schedules
404
+ if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
405
+ cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
406
+ cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
407
+ continue
408
+ epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
409
+
410
+ # Verify the maximum number of mainloop stages
411
+ mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
412
+ smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
413
+ mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
414
+ if mainloop_stages < 2:
415
+ # Mainloop stages must >= 2
416
+ continue
417
+
418
+ new_possible_operations.add(operation)
419
+ if len(new_possible_operations.all_operations) == 0:
420
+ raise RuntimeError(
421
+ "The epilogue consumes too much shared memory. "
422
+ "No valid tile description is found in the generator.")
423
+ self.possible_operations = new_possible_operations
424
+
425
+
426
+ def run_setup(self):
427
+ """
428
+ Steps that must be taken before caling `plan.run()`
429
+ """
430
+ # Initialize the memory pool if, if not already done
431
+ cutlass_cppgen.get_memory_pool()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for expressing shapes
35
+ """
36
+
37
+ from cutlass_library import (
38
+ ConvMode,
39
+ ConvKind,
40
+ LayoutType
41
+ )
42
+ from cutlass_cppgen.backend.c_types import (
43
+ Conv2DProblemSize_,
44
+ GemmCoord_,
45
+ GemmCoordBatched_
46
+ )
47
+
48
+
49
+ class MatrixCoord:
50
+ def __init__(self, row, col):
51
+ self._row = row
52
+ self._col = col
53
+
54
+ @property
55
+ def row(self):
56
+ return self._row
57
+
58
+ @property
59
+ def column(self):
60
+ return self._col
61
+
62
+ def leading_dimension(self, layout: LayoutType) -> int:
63
+ """
64
+ Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.
65
+
66
+ :param layout: layout of matrix
67
+ :type layout: cutlass_library.LayoutType
68
+
69
+ :returns: leading dimension
70
+ :rtype: int
71
+ """
72
+ if layout == LayoutType.RowMajor:
73
+ return self._col
74
+ elif layout == LayoutType.ColumnMajor:
75
+ return self._row
76
+ else:
77
+ raise Exception(f'Unsupported layout for leading dimension calculation: {layout}')
78
+
79
+
80
+ class GemmCoord:
81
+ def __init__(self, m: int, n: int, k: int):
82
+ self._m = m
83
+ self._n = n
84
+ self._k = k
85
+
86
+ @property
87
+ def m(self) -> int:
88
+ return self._m
89
+
90
+ @property
91
+ def n(self) -> int:
92
+ return self._n
93
+
94
+ @property
95
+ def k(self) -> int:
96
+ return self._k
97
+
98
+ @property
99
+ def mk(self) -> MatrixCoord:
100
+ return MatrixCoord(self._m, self._k)
101
+
102
+ @property
103
+ def mn(self) -> MatrixCoord:
104
+ return MatrixCoord(self._m, self._n)
105
+
106
+ @property
107
+ def kn(self) -> MatrixCoord:
108
+ return MatrixCoord(self._k, self._n)
109
+
110
+ @property
111
+ def ctype(self) -> GemmCoord_:
112
+ return GemmCoord_(self._m, self._n, self._k)
113
+
114
+ def batched_ctype(self, batch_count: int) -> GemmCoordBatched_:
115
+ return GemmCoordBatched_(self._m, self._n, self._k, batch_count)
116
+
117
+
118
+ class Conv2DProblemSize:
119
+ def __init__(
120
+ self, n: int, h: int, w: int, c: int,
121
+ k: int, r: int, s: int, c_: int,
122
+ pad_h: int, pad_w: int, stride_h: int, stride_w: int,
123
+ dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation,
124
+ split_k_slices: int=1, groups: int=1):
125
+
126
+ self.N = n
127
+ self.H = h
128
+ self.W = w
129
+ self.C = c
130
+ self.K = k
131
+ self.R = r
132
+ self.S = s
133
+ self.pad_h = pad_h
134
+ self.pad_w = pad_w
135
+ self.stride_h = stride_h
136
+ self.stride_w = stride_w
137
+ self.dilation_h = dilation_h
138
+ self.dilation_w = dilation_w
139
+ self.mode = int(mode)
140
+ self.split_k_slices = split_k_slices
141
+ self.groups = groups
142
+ self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1
143
+ self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1
144
+
145
+ @property
146
+ def ctype(self) -> Conv2DProblemSize_:
147
+ return Conv2DProblemSize_(self)
148
+
149
+ def implicit_gemm_size(self, kind: ConvKind):
150
+ if kind == ConvKind.Fprop:
151
+ return GemmCoord(
152
+ self.N * self.P * self.Q,
153
+ self.K,
154
+ self.R * self.S * self.C // self.groups
155
+ )
156
+ elif kind == ConvKind.Dgrad:
157
+ return GemmCoord(
158
+ self.N * self.H * self.W,
159
+ self.C,
160
+ self.R * self.S * self.K
161
+ )
162
+ elif kind == ConvKind.Wgrad:
163
+ return GemmCoord(
164
+ self.K,
165
+ self.R * self.S * self.C,
166
+ self.N * self.P * self.Q
167
+ )
168
+
169
+ @staticmethod
170
+ def from_sizes(input_size, weight_size):
171
+ K, R, S, _ = weight_size
172
+ pad_h = R // 2
173
+ pad_w = S // 2
174
+ stride_h = 1
175
+ stride_w = 1
176
+ dilation_h = 1
177
+ dilation_w = 1
178
+ return Conv2DProblemSize(
179
+ *input_size,
180
+ *weight_size,
181
+ pad_h, pad_w,
182
+ stride_h, stride_w,
183
+ dilation_h, dilation_w
184
+ )
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Registry of swizzling functions
35
+ """
36
+
37
+ from cutlass_library import SwizzlingFunctor
38
+
39
+
40
+ IdentitySwizzle1 = SwizzlingFunctor.Identity1
41
+ IdentitySwizzle2 = SwizzlingFunctor.Identity2
42
+ IdentitySwizzle4 = SwizzlingFunctor.Identity4
43
+ IdentitySwizzle8 = SwizzlingFunctor.Identity8
44
+ HorizontalSwizzle = SwizzlingFunctor.Horizontal
45
+ ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK
46
+ StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1
47
+ StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4
48
+ StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal
49
+
50
+
51
+ _swizzling_functors = [
52
+ IdentitySwizzle1,
53
+ IdentitySwizzle2,
54
+ IdentitySwizzle4,
55
+ IdentitySwizzle8,
56
+ HorizontalSwizzle,
57
+ ThreadblockSwizzleStreamK,
58
+ StridedDgradIdentitySwizzle1,
59
+ StridedDgradIdentitySwizzle4,
60
+ StridedDgradHorizontalSwizzle,
61
+ ]
62
+
63
+
64
+ def get_swizzling_functors():
65
+ return _swizzling_functors
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.utils.check import (
34
+ alignment_or_default,
35
+ calculate_smem_usage,
36
+ calculate_smem_usage_per_stage,
37
+ valid_cluster_shape,
38
+ valid_schedule,
39
+ valid_stage_count,
40
+ update_alignment,
41
+ )
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utility functions for checking constraints on kernels and calculating kernel attributes
35
+ """
36
+
37
+ import ctypes
38
+
39
+ from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.library import TileDescription
43
+
44
+
45
+ def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
46
+ """
47
+ Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
48
+
49
+ :param td: tile description to compute shared memory of
50
+ :type td: TileDescription
51
+ :param operation_kind: identifier for the type of operation being performed
52
+ :type operation_kind: cutlass_library.OperationKind
53
+
54
+ :return: number of bytes of shared memory consumed by a single stage
55
+ :rtype: int
56
+ """
57
+ m, n, k = td.blackwell_threadblock_shape
58
+ if td.is_2sm:
59
+ m //= 2
60
+
61
+ if operation_kind == OperationKind.Gemm:
62
+ stage_barrier_bytes = 32
63
+ return (
64
+ (DataTypeSize[td.math_instruction.element_a] * m * k // 8)
65
+ + (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
66
+ + stage_barrier_bytes
67
+ )
68
+ else:
69
+ raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}")
70
+
71
+
72
+ def calculate_smem_usage(operation) -> int:
73
+ """
74
+ Returns the amount of shared memory in bytes consumed by a kernel.
75
+
76
+ :return: number of bytes of shared memory consumed by the operation
77
+ :return: int
78
+ """
79
+ _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind)
80
+ return _per_stage * operation.tile_description.stages
81
+
82
+
83
+ def valid_stage_count(
84
+ cc: int,
85
+ kernel_cc: int,
86
+ td: TileDescription,
87
+ element_C: cutlass_cppgen.DataType = None,
88
+ element_D: cutlass_cppgen.DataType = None,
89
+ verbose: bool = True) -> tuple:
90
+ """
91
+ Checks whether a device with `cc` supports the number of stages within `tile_description`, both
92
+ based on raw limits on the number of stages and based on shared memory capacity
93
+
94
+ :param cc: compute capability of device in question
95
+ :type cc: int
96
+ :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
97
+ :type kernel_cc: int
98
+ :param td: tile description to check
99
+ :type td: TileDescription
100
+ :param element_C: data type of operand C
101
+ :type element_C: cutlass_cppgen.DataType
102
+ :param element_D: data type of operand D
103
+ :type element_D: cutlass_cppgen.DataType
104
+ :param verbose: whether to log warnings
105
+ :type verbose: bool
106
+
107
+ :return: tuple with the first element indicating whether the provided tile description is
108
+ valid for the provided device and the second element being an error message
109
+ :rtype: tuple
110
+ """
111
+ if kernel_cc in [90, 100, 101, 103]:
112
+ if (td.stages is None or td.stages == 0):
113
+ # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
114
+ # determines the stage count to use. Thus, all settings are valid in these scenarios.
115
+ return (True, "")
116
+ elif verbose:
117
+ cutlass_cppgen.logger.warning(
118
+ "Setting an explicit stage count for SM90 kernels currently may "
119
+ "result in compilation errors if the combination of tile shape, "
120
+ "stage count, and shared memory requirement of the epilogue exceeds "
121
+ "the available shared memory per SM.")
122
+
123
+ if td.stages <= 0:
124
+ return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
125
+
126
+ if cc < 80 and td.stages != 2:
127
+ return (False, f"Tile description has stage count of {td.stages}, "
128
+ f"but only 2 stages are supported on SM{cc}.")
129
+
130
+ # The calculation below does not consider shared memory used by the epilogue and, thus,
131
+ # only catches cases in which the mainloop exceeds the device's shared memory capacity.
132
+ # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
133
+ # mainloop and epilogue is shared.
134
+ smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
135
+ smem_usage_mainloop = (smem_per_stage * td.stages)
136
+ smem_arch = SharedMemPerCC[cc] << 10
137
+ if smem_usage_mainloop > smem_arch:
138
+ return ( False,
139
+ "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
140
+ f"Details:\n"
141
+ f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
142
+ f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
143
+ f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
144
+
145
+ return (True, "")
146
+
147
+
148
+ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
149
+ """
150
+ Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.
151
+
152
+ :param cc: compute capability of device in question
153
+ :type cc: int
154
+ :param cluster_shape: dimensions of thread block cluster shape to check
155
+ :type cluster_shape: list
156
+
157
+ :return: tuple with the first element indicating whether the provided cluster shape is
158
+ valid for the provided device and the second element being an error message
159
+ :rtype: tuple
160
+ """
161
+
162
+ if cc < 90 or cc in [120, 121]:
163
+ if cluster_shape != [1, 1, 1]:
164
+ return (False,
165
+ f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of "
166
+ f"{cluster_shape} for SM{cc}.")
167
+ else:
168
+ return (True, "")
169
+
170
+ if len(cluster_shape) != 3:
171
+ return (False,
172
+ f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}")
173
+
174
+ if cluster_shape[2] != 1:
175
+ return (False,
176
+ "CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
177
+ f"Received cluster shape of {cluster_shape}.")
178
+
179
+ return (True, "")
180
+
181
+
182
+ def valid_schedule(
183
+ cc: int,
184
+ kernel_schedule: cutlass_cppgen.KernelScheduleType,
185
+ epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
186
+ tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
187
+ """
188
+ Checks that the kernel and epilogue schedules passed in are a valid combination for
189
+ a device of compute capability ``cc``.
190
+
191
+ :param cc: compute capability of device in question
192
+ :type cc: int
193
+ :param kernel_schedule: kernel schedule type
194
+ :type kernel_schedule: cutlass_cppgen.KernelScheduleType
195
+ :param epilogue_schedule: epilogue schedule type
196
+ :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
197
+ :param tile_scheduler: tile scheduler type
198
+ :type tile_scheduler: cutlass_cppgen.TileSchedulerType
199
+
200
+ :return: tuple with the first element indicating whether the provided schedules are
201
+ valid for the provided device and the second element being an error message
202
+ :rtype: tuple
203
+ """
204
+ kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
205
+ epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
206
+ tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
207
+ if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default):
208
+ return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)")
209
+
210
+ if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)):
211
+ return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
212
+
213
+ if not tile_scheduler_default:
214
+ cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
215
+ cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
216
+ if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
217
+ return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
218
+ return (True, "")
219
+
220
+
221
+ def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
222
+ """
223
+ Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
224
+ that `alignment_provided` does not exceed `default_alignment`.
225
+
226
+ :param alignment_provided: alignment preference specified. Can be None.
227
+ :type alignment_provided: int
228
+ :param default_alignment: alignment to use if `alignment_provided` is None
229
+ :type default_alignment: int
230
+
231
+ :return: alignment to use
232
+ :rtype: int
233
+ """
234
+ if alignment_provided is not None:
235
+ if alignment_provided > default_alignment:
236
+ raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
237
+ return alignment_provided
238
+
239
+ return default_alignment
240
+
241
+
242
+ def update_alignment(alignment_provided:int, default_alignment: int) -> int:
243
+ """
244
+ Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
245
+ that `alignment_provided` does not exceed `default_alignment`.
246
+
247
+ :param alignment_provided: alignment preference specified. Can be None.
248
+ :type alignment_provided: int
249
+ :param default_alignment: alignment to use if `alignment_provided` is None
250
+ :type default_alignment: int
251
+
252
+ :return: alignment to use
253
+ :rtype: int
254
+ """
255
+ if alignment_provided is not None:
256
+ if alignment_provided > default_alignment:
257
+ if alignment_provided % default_alignment == 0:
258
+ return default_alignment
259
+ raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
260
+ return alignment_provided
261
+
262
+ return default_alignment
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utility functions for converting between frontend datatypes and CUTLASS datatypes
35
+ """
36
+
37
+ import cutlass_cppgen
38
+ from cutlass_library import (
39
+ DataTypeSize,
40
+ MathOperation,
41
+ MathInstruction
42
+ )
43
+ from cutlass_cppgen.backend.library import (
44
+ TileDescription,
45
+ )
46
+
47
+ bfloat16_available = None
48
+ cupy_available = None
49
+ numpy_available = None
50
+ torch_available = None
51
+ _library_to_cupy_dict = None
52
+ _library_to_numpy_dict = None
53
+ _library_to_torch_dict = None
54
+ _torch_to_library_dict = None
55
+
56
+
57
+ def is_numpy_available():
58
+ global numpy_available, _library_to_numpy_dict
59
+ if numpy_available is None:
60
+ try:
61
+ import numpy as np
62
+
63
+ numpy_available = True
64
+ _library_to_numpy_dict = {
65
+ cutlass_cppgen.DataType.f16: np.float16,
66
+ cutlass_cppgen.DataType.f32: np.float32,
67
+ cutlass_cppgen.DataType.f64: np.float64,
68
+ cutlass_cppgen.DataType.s8: np.int8,
69
+ cutlass_cppgen.DataType.s32: np.int32,
70
+ }
71
+ except ImportError:
72
+ numpy_available = False
73
+ _library_to_numpy_dict = {}
74
+ return numpy_available
75
+
76
+
77
+ def is_numpy_tensor(inp) -> bool:
78
+ if is_numpy_available():
79
+ import numpy as np
80
+ return isinstance(inp, np.ndarray)
81
+ return False
82
+
83
+
84
+ def numpy_library_type(inp) -> cutlass_cppgen.DataType:
85
+ if is_numpy_available():
86
+ import numpy as np
87
+ if inp == np.float16:
88
+ return cutlass_cppgen.DataType.f16
89
+ elif inp == np.float32:
90
+ return cutlass_cppgen.DataType.f32
91
+ elif inp == np.float64:
92
+ return cutlass_cppgen.DataType.f64
93
+ elif inp == np.int8:
94
+ return cutlass_cppgen.DataType.s8
95
+ elif inp == np.int32:
96
+ return cutlass_cppgen.DataType.s32
97
+ return None
98
+
99
+
100
+ def numpy_type(inp):
101
+ return _library_to_numpy_dict.get(inp, None)
102
+
103
+
104
+ def is_cupy_available():
105
+ global cupy_available
106
+ if cupy_available is None:
107
+ try:
108
+ import cupy as cp
109
+
110
+ cupy_available = True
111
+ _library_to_cupy_dict = {
112
+ cutlass_cppgen.DataType.f16: cp.float16,
113
+ cutlass_cppgen.DataType.f32: cp.float32,
114
+ cutlass_cppgen.DataType.f64: cp.float64,
115
+ cutlass_cppgen.DataType.s8: cp.int8,
116
+ cutlass_cppgen.DataType.s32: cp.int32,
117
+ }
118
+ except ImportError:
119
+ cupy_available = False
120
+ _library_to_cupy_dict = {}
121
+ return cupy_available
122
+
123
+
124
+ def is_cupy_tensor(inp) -> bool:
125
+ if is_cupy_available():
126
+ import cupy as cp
127
+ return isinstance(inp, cp.ndarray)
128
+ return False
129
+
130
+
131
+ def cupy_library_type(inp) -> cutlass_cppgen.DataType:
132
+ if is_cupy_available():
133
+ import cupy as cp
134
+ if inp == cp.float16:
135
+ return cutlass_cppgen.DataType.f16
136
+ elif inp == cp.float32:
137
+ return cutlass_cppgen.DataType.f32
138
+ elif inp == cp.float64:
139
+ return cutlass_cppgen.DataType.f64
140
+ return None
141
+
142
+
143
+ def cupy_type(inp):
144
+ return _library_to_cupy_dict.get(inp, None)
145
+
146
+
147
+ def is_torch_available():
148
+ global torch_available, _library_to_torch_dict, _torch_to_library_dict
149
+ if torch_available is None:
150
+ try:
151
+ import torch
152
+
153
+ torch_available = True
154
+ _torch_to_library_dict = {
155
+ torch.half: cutlass_cppgen.DataType.f16,
156
+ torch.float16: cutlass_cppgen.DataType.f16,
157
+ torch.bfloat16: cutlass_cppgen.DataType.bf16,
158
+ torch.float: cutlass_cppgen.DataType.f32,
159
+ torch.float32: cutlass_cppgen.DataType.f32,
160
+ torch.double: cutlass_cppgen.DataType.f64,
161
+ torch.float64: cutlass_cppgen.DataType.f64,
162
+ torch.int8: cutlass_cppgen.DataType.s8,
163
+ torch.int32: cutlass_cppgen.DataType.s32,
164
+ torch.uint8: cutlass_cppgen.DataType.u8,
165
+ }
166
+
167
+ _library_to_torch_dict = {
168
+ cutlass_cppgen.DataType.f16: torch.half,
169
+ cutlass_cppgen.DataType.f16: torch.float16,
170
+ cutlass_cppgen.DataType.bf16: torch.bfloat16,
171
+ cutlass_cppgen.DataType.f32: torch.float,
172
+ cutlass_cppgen.DataType.f32: torch.float32,
173
+ cutlass_cppgen.DataType.f64: torch.double,
174
+ cutlass_cppgen.DataType.f64: torch.float64,
175
+ cutlass_cppgen.DataType.s8: torch.int8,
176
+ cutlass_cppgen.DataType.s32: torch.int32,
177
+ cutlass_cppgen.DataType.u8: torch.uint8,
178
+ }
179
+
180
+ def possibly_add_type(torch_type_name, cutlass_type):
181
+ # Only try adding the type if the version of torch being used supports it
182
+ if hasattr(torch, torch_type_name):
183
+ torch_type = getattr(torch, torch_type_name)
184
+ _torch_to_library_dict[torch_type] = cutlass_type
185
+ _library_to_torch_dict[cutlass_type] = torch_type
186
+
187
+ possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
188
+ possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
189
+
190
+ except ImportError:
191
+ torch_available = False
192
+ _torch_to_library_dict = {}
193
+ _library_to_torch_dict = {}
194
+ return torch_available
195
+
196
+
197
+ def is_torch_tensor(inp) -> bool:
198
+ if is_torch_available():
199
+ import torch
200
+ return isinstance(inp, torch.Tensor)
201
+ return False
202
+
203
+
204
+ def torch_library_type(inp) -> cutlass_cppgen.DataType:
205
+ return _torch_to_library_dict.get(inp, None)
206
+
207
+
208
+ def torch_type(inp):
209
+ return _library_to_torch_dict.get(inp, None)
210
+
211
+
212
+ def is_bfloat16_available():
213
+ global bfloat16_available
214
+
215
+ if bfloat16_available is None:
216
+ try:
217
+ import bfloat16
218
+
219
+ bfloat16_available = True
220
+ except ImportError:
221
+ bfloat16_available = False
222
+ return bfloat16_available
223
+
224
+
225
+ def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
226
+ if is_bfloat16_available():
227
+ import bfloat16
228
+ if inp == bfloat16.bfloat16:
229
+ return cutlass_cppgen.DataType.bf16
230
+
231
+
232
+ def bfloat16_type(inp):
233
+ if is_bfloat16_available():
234
+ import bfloat16
235
+ if inp == cutlass_cppgen.DataType.bf16:
236
+ return bfloat16.bfloat16
237
+
238
+
239
+ def library_type(inp):
240
+ if inp in DataTypeSize:
241
+ return inp
242
+
243
+ for cvt_fn in [
244
+ bfloat16_library_type,
245
+ cupy_library_type,
246
+ numpy_library_type,
247
+ torch_library_type,
248
+ ]:
249
+ out = cvt_fn(inp)
250
+ if out is not None:
251
+ return out
252
+
253
+ raise Exception(f"No available conversion from type {inp} to a library type.")
254
+
255
+
256
+ def _tensor_from_numpy(np_tensor):
257
+ dtype = library_type(np_tensor.dtype)
258
+ if np_tensor.flags.c_contiguous:
259
+ layout = cutlass_cppgen.LayoutType.RowMajor
260
+ elif np_tensor.flags.f_contiguous:
261
+ layout = cutlass_cppgen.LayoutType.ColumnMajor
262
+ return (dtype, layout)
263
+
264
+
265
+ def _tensor_from_torch(pt_tensor):
266
+ dtype = library_type(pt_tensor.dtype)
267
+ return (dtype, cutlass_cppgen.LayoutType.RowMajor)
268
+
269
+
270
+ def get_datatype_and_layout(tensor):
271
+ if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
272
+ return _tensor_from_numpy(tensor)
273
+ elif is_torch_tensor(tensor):
274
+ return _tensor_from_torch(tensor)
275
+ elif isinstance(tensor, float) or isinstance(tensor, int):
276
+ return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
277
+ else:
278
+ raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
279
+
280
+
281
+ def get_tensor_shape(tensor, op="GEMM"):
282
+ if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
283
+ return tensor.shape
284
+ elif is_torch_tensor(tensor):
285
+ size = tensor.size()
286
+ if op == "CONV":
287
+ # PyTorch Tensors have shape NCHW
288
+ return (size[0], size[2], size[3], size[1])
289
+ else:
290
+ return tuple(tensor.size())
291
+ elif isinstance(tensor, float) or isinstance(tensor, int):
292
+ return (1,)
293
+ else:
294
+ raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
295
+
296
+
297
+ _math_operation_value_map = {x.value: x for x in MathOperation}
298
+
299
+
300
+ def backend_math_operation(math_op: MathOperation):
301
+ if math_op.value not in _math_operation_value_map.keys():
302
+ raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
303
+ return _math_operation_value_map[math_op.value]
304
+
305
+
306
+ def construct_backend_td(td: cutlass_cppgen.TileDescription,
307
+ kernel_schedule: cutlass_cppgen.KernelScheduleType,
308
+ epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
309
+ tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
310
+ mi = td.math_instruction
311
+ backend_mi = MathInstruction(
312
+ mi.instruction_shape,
313
+ mi.element_a,
314
+ mi.element_b,
315
+ mi.element_accumulator,
316
+ mi.opcode_class,
317
+ backend_math_operation(mi.math_operation)
318
+ )
319
+ cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
320
+ return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
321
+ backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
322
+
323
+
324
+ def td_from_profiler_op(op) -> TileDescription:
325
+ """
326
+ Converts the profiler's TileDescription in ``op`` into the backend TileDescription
327
+
328
+ :param op: profiler Operation
329
+
330
+ :returns: backend TileDescription
331
+ :rtype: cutlass_cppgen.backend.TileDescription
332
+ """
333
+ kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
334
+ eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
335
+ tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
336
+ return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
337
+
338
+
339
+ def td_from_profiler_td(td: TileDescription) -> TileDescription:
340
+ """
341
+ Converts the profiler's TileDescription into the backend TileDescription
342
+
343
+ :param td: profiler TileDescription
344
+ :type td: cutlass_cppgen.TileDescription
345
+
346
+ :returns: backend TileDescription
347
+ :rtype: cutlass_cppgen.backend.TileDescription
348
+ """
349
+ return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
350
+
351
+
352
+ def to_camel_case(snake_str):
353
+ return "".join(x.capitalize() for x in snake_str.lower().split("_"))
354
+
355
+
356
+ def getattr_enum(obj, attr_name):
357
+ # The attr_name is under the snake_case
358
+ camel_attr = to_camel_case(attr_name)
359
+ if hasattr(obj, camel_attr):
360
+ return getattr(obj, camel_attr)
361
+ else:
362
+ raise Exception(f"Invalid option: {attr_name}")
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ import importlib
33
+ from typing import Any
34
+
35
+ def lazy_import(mod_name: str) -> Any:
36
+ class Lazy:
37
+ def __getattr__(self, name:str) -> Any:
38
+ module = importlib.import_module(mod_name)
39
+ return getattr(module, name)
40
+
41
+ return Lazy()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Profiler based on the cuda events
35
+ """
36
+
37
+ import re
38
+ import subprocess
39
+
40
+ from cutlass_cppgen.utils.lazy_import import lazy_import
41
+ cuda = lazy_import("cuda.cuda")
42
+ cudart = lazy_import("cuda.cudart")
43
+ import numpy as np
44
+
45
+ from cutlass_cppgen import CUTLASS_PATH
46
+ from cutlass_cppgen.backend.library import DataTypeSize
47
+ from cutlass_cppgen.op.op import OperationBase
48
+ from cutlass_cppgen.shape import GemmCoord
49
+ from cutlass_cppgen.utils.datatypes import is_numpy_tensor
50
+
51
+
52
+ class GpuTimer:
53
+ def __init__(self) -> None:
54
+ self.events = [
55
+ cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
56
+ cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
57
+ ]
58
+
59
+ def start(self, stream=None):
60
+ if not stream:
61
+ stream = cuda.CUstream(0)
62
+
63
+ (err,) = cuda.cuEventRecord(self.events[0], stream)
64
+ if err != cuda.CUresult.CUDA_SUCCESS:
65
+ raise RuntimeError(f"CUDA Error {str(err)}")
66
+
67
+ def stop(self, stream=None):
68
+ if not stream:
69
+ stream = cuda.CUstream(0)
70
+
71
+ (err,) = cuda.cuEventRecord(self.events[1], stream)
72
+ if err != cuda.CUresult.CUDA_SUCCESS:
73
+ raise RuntimeError(f"CUDA Error {str(err)}")
74
+ pass
75
+
76
+ def stop_and_wait(self, stream=None):
77
+ if not stream:
78
+ stream = cuda.CUstream(0)
79
+
80
+ self.stop(stream)
81
+ if stream:
82
+ (err,) = cuda.cuStreamSynchronize(stream)
83
+ if err != cuda.CUresult.CUDA_SUCCESS:
84
+ raise RuntimeError(f"CUDA Error {str(err)}")
85
+ else:
86
+ (err,) = cudart.cudaDeviceSynchronize()
87
+ if err != cuda.CUresult.CUDA_SUCCESS:
88
+ raise RuntimeError(f"CUDA Error {str(err)}")
89
+
90
+ def duration(self, iterations=1):
91
+ err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
92
+ if err != cuda.CUresult.CUDA_SUCCESS:
93
+ raise RuntimeError(f"CUDA Error {str(err)}")
94
+ return duration / float(iterations)
95
+
96
+
97
+ class CUDAEventProfiler:
98
+ def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None:
99
+ self.arguments = op.run(*args, **kwargs)
100
+ self.operation = op.operation
101
+ self.warmup_iterations = warmup_iterations
102
+ self.iterations = iterations
103
+ self.timer = GpuTimer()
104
+
105
+ #
106
+ # Cutlass Python Interface Profiler
107
+ #
108
+
109
+ def __call__(self):
110
+ for _ in range(self.warmup_iterations):
111
+ self.operation.run(self.arguments)
112
+
113
+ self.timer.start()
114
+ for _ in range(self.iterations):
115
+ self.operation.run(self.arguments)
116
+
117
+ self.timer.stop_and_wait()
118
+ runtime = self.timer.duration(self.iterations)
119
+ return runtime
120
+
121
+ #
122
+ # CUTLASS Profiler
123
+ #
124
+
125
+ def run_cutlass_profiler(self):
126
+ alpha = 1.0
127
+ beta = 1.0
128
+
129
+ profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
130
+ kernel_name = self.operation.procedural_name()
131
+ verification_providers = "device"
132
+ provider = "cutlass"
133
+ problem_size = self.arguments.problem_size
134
+
135
+ if "cutlass3x" in kernel_name:
136
+ # cutlass3x generator only have column-major output
137
+ layout_name = self.operation.layout_name_3x()
138
+ if layout_name[-1] == "t":
139
+ new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"])
140
+ problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
141
+ kernel_name = kernel_name.replace(layout_name, new_layout_name)
142
+
143
+ batch_count = self.arguments.batch_count
144
+
145
+ cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
146
+ f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
147
+ f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\
148
+ f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"
149
+
150
+ result = subprocess.getoutput(cmd)
151
+
152
+ m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
153
+ runtime = float(m.group("runtime"))
154
+
155
+ m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
156
+ bytes = int(m.group("bytes"))
157
+
158
+ m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
159
+ flops = int(m.group("flops"))
160
+
161
+ # check if the problem size matches
162
+ assert bytes == self.bytes(problem_size, batch_count, beta)
163
+ assert flops == self.flops(problem_size, batch_count, beta)
164
+
165
+ return runtime
166
+
167
+ def bytes(self, problem_size, batch_count=1, beta=0.0):
168
+ m = problem_size.m()
169
+ n = problem_size.n()
170
+ k = problem_size.k()
171
+
172
+ bytes = (
173
+ (DataTypeSize[self.operation.A.element] * m // 8) * k
174
+ + (DataTypeSize[self.operation.B.element] * n // 8) * k
175
+ + (DataTypeSize[self.operation.C.element] * m // 8) * n
176
+ )
177
+
178
+ if beta != 0:
179
+ bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
180
+
181
+ bytes *= batch_count
182
+
183
+ return bytes
184
+
185
+ def flops(self, problem_size, batch_count=1, beta=0.0):
186
+ m = problem_size.m()
187
+ n = problem_size.n()
188
+ k = problem_size.k()
189
+
190
+ flops_ = (m * n * k) * 2 * batch_count
191
+
192
+ if beta != 0:
193
+ flops_ += m * n * batch_count * 2
194
+
195
+ return flops_
196
+
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import os
34
+ import sys
35
+
36
+ from . import conv2d_operation
37
+ from . import conv3d_operation
38
+ from . import emit_kernel_listing
39
+ from . import gemm_operation
40
+
41
+ if '-m' not in sys.argv:
42
+ # Do not import generator when running python -m cutlass_library.generator to
43
+ # avoid double-import warnings
44
+ from . import generator
45
+
46
+ from . import library
47
+ from . import manifest
48
+ from . import rank_2k_operation
49
+ from . import rank_k_operation
50
+ from . import symm_operation
51
+ from . import trmm_operation
52
+ # Make enum types from library.py accessible via cutlass_library.*
53
+ from .library import *
54
+
55
+ # Set up `source` to point to the path containing the CUTLASS source.
56
+ # Check first if the path contains a `source` subdirectory -- this will
57
+ # be the case when the package has been installed via pip. Otherwise,
58
+ # default to the root of CUTLASS.
59
+ install_source_path = os.path.join(__path__[0], 'source')
60
+ if os.path.isdir(install_source_path):
61
+ source_path = install_source_path
62
+ else:
63
+ source_path = os.path.join(__path__[0], '../..')
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting Conv2d kernels
35
+ """
36
+
37
+ import enum
38
+ import logging
39
+ import os.path
40
+ import shutil
41
+ from string import Template
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
49
+ except ImportError:
50
+ from library import *
51
+ from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
52
+
53
+ _LOGGER = logging.getLogger(__name__)
54
+
55
+ ###################################################################################################
56
+
57
+ #
58
+ class Conv2dOperation:
59
+ #
60
+ def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
61
+ stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
62
+ group_mode = GroupMode.NoneGroup):
63
+
64
+ self.operation_kind = OperationKind.Conv2d
65
+ self.arch = arch
66
+ self.tile_description = tile_description
67
+ self.conv_kind = conv_kind
68
+ self.A = A
69
+ self.B = B
70
+ self.C = C
71
+ self.element_epilogue = element_epilogue
72
+ self.epilogue_functor = epilogue_functor
73
+ self.iterator_algorithm = iterator_algorithm
74
+ self.stride_support = stride_support
75
+ self.swizzling_functor = swizzling_functor
76
+ self.group_mode = group_mode
77
+ #
78
+ def is_complex(self):
79
+ complex_operators = [
80
+ MathOperation.multiply_add_complex,
81
+ MathOperation.multiply_add_complex_gaussian
82
+ ]
83
+ return self.tile_description.math_instruction.math_operation in complex_operators
84
+
85
+ #
86
+ def is_mixed_input(self):
87
+ return self.A.element != self.B.element
88
+
89
+ #
90
+ def accumulator_type(self):
91
+ accum = self.tile_description.math_instruction.element_accumulator
92
+
93
+ if self.is_complex():
94
+ return get_complex_from_real(accum)
95
+
96
+ return accum
97
+
98
+ #
99
+ def core_name(self):
100
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
101
+
102
+ intermediate_type = ''
103
+
104
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
105
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
106
+ if self.tile_description.math_instruction.element_a != self.A.element and \
107
+ self.tile_description.math_instruction.element_a != self.accumulator_type():
108
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
109
+ else:
110
+ inst_shape = ''
111
+
112
+ return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
113
+ inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
114
+
115
+ #
116
+ def extended_name(self):
117
+ ''' Append data types if they differ from compute type. '''
118
+ if self.C.element != self.tile_description.math_instruction.element_accumulator and \
119
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
120
+ extended_name = "${element_c}_${core_name}_${element_a}"
121
+ elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
122
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
123
+ extended_name = "${core_name}_${element_a}"
124
+ else:
125
+ extended_name = "${core_name}"
126
+
127
+ extended_name = SubstituteTemplate(extended_name, {
128
+ 'element_a': DataTypeNames[self.A.element],
129
+ 'element_c': DataTypeNames[self.C.element],
130
+ 'core_name': self.core_name()
131
+ })
132
+
133
+ return extended_name
134
+
135
+ #
136
+ def layout_name(self):
137
+ return "%s" % (ShortLayoutTypeNames[self.A.layout])
138
+
139
+ #
140
+ def configuration_name(self):
141
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
142
+
143
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
144
+
145
+ threadblock = self.tile_description.procedural_name()
146
+
147
+ # grouped conv
148
+ if self.group_mode != GroupMode.NoneGroup:
149
+ group_conv_name = f"{GroupModeNames[self.group_mode]}_"
150
+ else:
151
+ group_conv_name = ""
152
+
153
+ if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad:
154
+ configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
155
+ else:
156
+ configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
157
+
158
+ return SubstituteTemplate(
159
+ configuration_name,
160
+ {
161
+ 'opcode_class': opcode_class_name,
162
+ 'extended_name': self.extended_name(),
163
+ 'threadblock': threadblock,
164
+ 'layout': self.layout_name(),
165
+ 'alignment': "%d" % self.A.alignment,
166
+ 'group_conv_name': group_conv_name
167
+ }
168
+ )
169
+
170
+ #
171
+ def procedural_name(self):
172
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
173
+ return self.configuration_name()
174
+
175
+ ###################################################################################################
176
+ #
177
+ # Emits single instances of a CUTLASS device-wide operator
178
+ #
179
+ ###################################################################################################
180
+
181
+ class EmitConv2dInstance:
182
+ def __init__(self):
183
+ # Emitter for CUTLASS 3 convolution operations
184
+ self.conv3x_emitter = EmitConv3xInstance()
185
+ self.template = """
186
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
187
+ using ${operation_name}_base =
188
+ typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
189
+ ${element_a},
190
+ ${layout_a},
191
+ ${element_b},
192
+ ${layout_b},
193
+ ${element_c},
194
+ ${layout_c},
195
+ ${element_accumulator},
196
+ ${opcode_class},
197
+ ${arch},
198
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
199
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
200
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
201
+ ${epilogue_functor}<
202
+ ${element_c},
203
+ ${epilogue_vector_length},
204
+ ${element_accumulator},
205
+ ${element_epilogue}
206
+ >,
207
+ ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
208
+ ${stages},
209
+ ${math_operator},
210
+ ${iterator_algorithm},
211
+ ${stride_support},
212
+ ${align_a},
213
+ ${align_b}
214
+ >::Kernel;
215
+ """
216
+ self.template_group_conv = """
217
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
218
+ using ${operation_name}_base =
219
+ typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
220
+ ${element_a},
221
+ ${layout_a},
222
+ ${element_b},
223
+ ${layout_b},
224
+ ${element_c},
225
+ ${layout_c},
226
+ ${element_accumulator},
227
+ ${opcode_class},
228
+ ${arch},
229
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
230
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
231
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
232
+ ${epilogue_functor}<
233
+ ${element_c},
234
+ ${epilogue_vector_length},
235
+ ${element_accumulator},
236
+ ${element_epilogue}
237
+ >,
238
+ ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
239
+ ${stages},
240
+ ${math_operator},
241
+ ${group_mode},
242
+ ${iterator_algorithm},
243
+ ${stride_support},
244
+ ${align_a},
245
+ ${align_b}
246
+ >::Kernel;
247
+ """
248
+ self.template_depthwise_direct_conv = """
249
+ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
250
+ using ${operation_name}_base =
251
+ typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
252
+ ${element_a},
253
+ ${layout_a},
254
+ ${element_b},
255
+ ${layout_b},
256
+ ${element_c},
257
+ ${layout_c},
258
+ ${element_accumulator},
259
+ ${opcode_class},
260
+ ${arch},
261
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
262
+ cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
263
+ cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
264
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
265
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
266
+ ${epilogue_functor}<
267
+ ${element_c},
268
+ ${epilogue_vector_length},
269
+ ${element_accumulator},
270
+ ${element_epilogue},
271
+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
272
+ >,
273
+
274
+ cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
275
+ 1,
276
+ ${threadblock_output_shape_n},
277
+ ${threadblock_output_shape_p},
278
+ ${threadblock_output_shape_q}>,
279
+ ${stages},
280
+ ${math_operator},
281
+ ${iterator_algorithm},
282
+ ${stride_support},
283
+ cutlass::MatrixShape<${stride_r}, ${stride_s}>,
284
+ cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
285
+ >::Kernel;
286
+ """
287
+
288
+ def arch_number_to_type(self, arch: int):
289
+ return f"cutlass::arch::Sm{arch}"
290
+
291
+ def emit(self, operation):
292
+ _LOGGER.debug("*** EmitConv2dInstance::emit")
293
+ _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
294
+
295
+ if hasattr(operation, 'is_3x') and operation.is_3x:
296
+ _LOGGER.debug("*** CUTLASS 3 operation")
297
+ return self.conv3x_emitter.emit(operation)
298
+
299
+ _LOGGER.debug("*** CUTLASS 2 operation")
300
+
301
+ warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
302
+
303
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
304
+
305
+ values = {
306
+ 'operation_name': operation.procedural_name(),
307
+ 'conv_kind': ConvKindTag[operation.conv_kind],
308
+ 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
309
+ 'element_a': DataTypeTag[operation.A.element],
310
+ 'layout_a': LayoutTag[operation.A.layout],
311
+ 'element_b': DataTypeTag[operation.B.element],
312
+ 'layout_b': LayoutTag[operation.B.layout],
313
+ 'element_c': DataTypeTag[operation.C.element],
314
+ 'layout_c': LayoutTag[operation.C.layout],
315
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
316
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
317
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
318
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
319
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
320
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
321
+ 'warp_shape_m': str(warp_shape[0]),
322
+ 'warp_shape_n': str(warp_shape[1]),
323
+ 'warp_shape_k': str(warp_shape[2]),
324
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
325
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
326
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
327
+ 'epilogue_vector_length': str(epilogue_vector_length),
328
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
329
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
330
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
331
+ 'stages': str(operation.tile_description.stages),
332
+ 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
333
+ 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
334
+ 'stride_support': StrideSupportTag[operation.stride_support],
335
+ 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
336
+ MathOperationTag[operation.tile_description.math_instruction.math_operation],
337
+ 'align_a': str(operation.A.alignment),
338
+ 'align_b': str(operation.B.alignment),
339
+ }
340
+
341
+ if operation.group_mode == GroupMode.NoneGroup:
342
+ _LOGGER.debug("*** group_mode=NoneGroup")
343
+ return SubstituteTemplate(self.template, values)
344
+
345
+ elif operation.group_mode == GroupMode.Depthwise:
346
+ _LOGGER.debug("*** group_mode=Depthwise")
347
+ values['group_mode'] = GroupModeTag[operation.group_mode]
348
+ # Setup other template params
349
+ values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
350
+ values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
351
+ values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
352
+
353
+ values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
354
+
355
+ values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
356
+ values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
357
+
358
+ values['stride_r'] = str(operation.tile_description.stride[0])
359
+ values['stride_s'] = str(operation.tile_description.stride[1])
360
+
361
+ values['dilation_r'] = str(operation.tile_description.dilation[0])
362
+ values['dilation_s'] = str(operation.tile_description.dilation[1])
363
+
364
+ return SubstituteTemplate(self.template_depthwise_direct_conv, values)
365
+
366
+ else:
367
+ _LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode])
368
+ values['group_mode'] = GroupModeTag[operation.group_mode]
369
+ return SubstituteTemplate(self.template_group_conv, values)
370
+
371
+ ###################################################################################################
372
+ #
373
+ # Generator functions for all layouts
374
+ #
375
+ ###################################################################################################
376
+
377
+ #
378
+ def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
379
+ _LOGGER.debug("*** GenerateConv2dTensorOp")
380
+
381
+ for tile in tile_descriptions:
382
+ for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
383
+
384
+ if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
385
+
386
+ #
387
+ output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
388
+ if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
389
+ else [tile.math_instruction.element_accumulator,]
390
+
391
+ for output_type in output_types:
392
+ A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
393
+ B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
394
+ C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type])))
395
+
396
+ manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
397
+
398
+ class EmitConv2dIncludes:
399
+ '''Emit includes that are specific to the operation.'''
400
+
401
+ def __init__(self):
402
+ self.includes = ['conv2d_operation.h']
403
+ self.emitter_3x = EmitConv3xIncludes()
404
+
405
+ def operation_is_3x(self, operation) -> bool:
406
+ """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
407
+ return hasattr(operation, 'is_3x') and operation.is_3x
408
+
409
+ def emit(self, operation) -> str:
410
+ if self.operation_is_3x(operation):
411
+ return self.emitter_3x.emit(operation)
412
+
413
+ return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
414
+ "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
415
+
416
+ ###################################################################################################
417
+ #
418
+ # Emitters functions for all targets
419
+ #
420
+ ###################################################################################################
421
+
422
+ class EmitConv2dConfigurationLibrary:
423
+ def __init__(self, operation_path, configuration_name):
424
+ self.configuration_name = configuration_name
425
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
426
+
427
+ self.instance_emitter = EmitConv2dInstance()
428
+ self.includes_emitter = EmitConv2dIncludes()
429
+
430
+ self.header_template = """
431
+ /*
432
+ Generated by conv2d_operation.py - Do not edit.
433
+ */
434
+
435
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
436
+
437
+ #include "cutlass/cutlass.h"
438
+ #include "cutlass/library/library.h"
439
+ #include "cutlass/library/manifest.h"
440
+
441
+ #include "library_internal.h"
442
+ """
443
+
444
+ self.instance_template = """
445
+ ${stub_begin}
446
+ ${operation_instance}
447
+ // Derived class
448
+ struct ${operation_name} :
449
+ public ${operation_name}_base { };
450
+ ${stub_end}
451
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
452
+
453
+ """
454
+
455
+ self.configuration_header = """
456
+
457
+ namespace cutlass {
458
+ namespace library {
459
+
460
+ // Initialize all instances
461
+ void initialize_${configuration_name}(Manifest &manifest) {
462
+ """
463
+
464
+ self.configuration_instance = """${stub_begin}
465
+ using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
466
+ ${operation_name}>;
467
+
468
+ manifest.append(new cutlass::library::${operation_wrapper}<
469
+ Operation_${operation_name}
470
+ >(
471
+ "${operation_name}"
472
+ ));
473
+ ${stub_end}
474
+ """
475
+
476
+ self.configuration_epilogue = "}\n"
477
+
478
+ self.epilogue_template = """
479
+
480
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
481
+
482
+ } // namespace library
483
+ } // namespace cutlass
484
+
485
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
486
+
487
+ """
488
+
489
+ def operation_is_3x(self, operation):
490
+ """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
491
+ return hasattr(operation, 'is_3x') and operation.is_3x
492
+
493
+ def __enter__(self):
494
+ """
495
+ Open the configuration_file, and write the "header" C++ code to it.
496
+
497
+ The "header" consists of a comment (that this is generated code,
498
+ so it should not be edited), and includes that are common
499
+ to all kinds of kernels.
500
+ """
501
+ _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__')
502
+ _LOGGER.debug('*** configuration_path (file to write): ' +
503
+ str(self.configuration_path))
504
+ _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
505
+ self.configuration_file = open(self.configuration_path, "w")
506
+
507
+ self.configuration_file.write(SubstituteTemplate(self.header_template, {
508
+ 'configuration_name': self.configuration_name
509
+ }))
510
+ self.operations = []
511
+ return self
512
+
513
+ def emit(self, operation):
514
+ """
515
+ Write three pieces of C++ code to the configuration_file
516
+ (that was opened by the __enter__ method above):
517
+
518
+ 1. the header includes that are specific to the operation
519
+ (CUTLASS 2 vs. CUTLASS 3);
520
+
521
+ 2. the "operation instance" (a "using" declaration ending in "_base"); and
522
+
523
+ 3. the "operation name" (declaration and definition of a derived class
524
+ of the above operation instance).
525
+
526
+ The "using" declaration turns a C++ class name, possibly namespace-qualified,
527
+ possibly also with angle brackets, into a C-style, easily demangled identifier.
528
+ """
529
+ _LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit')
530
+ _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
531
+ self.operations.append(operation)
532
+
533
+ self.configuration_file.write(self.includes_emitter.emit(operation))
534
+
535
+ stub_begin = ''
536
+ stub_end = ''
537
+ # It can be useful to stub (comment) out instantiations for testing.
538
+ # In this case, one need only set is_stub to True.
539
+ is_stub = False
540
+ if is_stub:
541
+ stub_begin = "// STUB for now\n#if 0"
542
+ stub_end = '#endif // 0'
543
+
544
+ self.configuration_file.write(Template(self.instance_template).substitute({
545
+ 'configuration_name': self.configuration_name,
546
+ 'operation_name': operation.procedural_name(),
547
+ 'operation_instance': self.instance_emitter.emit(operation),
548
+ 'stub_begin': stub_begin,
549
+ 'stub_end': stub_end
550
+ }))
551
+
552
+ def __exit__(self, exception_type, exception_value, traceback):
553
+ """
554
+ Write the rest of the C++ code to the configuration_file, and close the file.
555
+
556
+ The "rest of the C++ code" has the following components.
557
+
558
+ 1. Configuration header: Open the namespace(s), and open the definition
559
+ of the "initialize_${configuration_name}" registration function
560
+ that registers the operation with the Manifest.
561
+ ("Registration" helps turn C++ compile-time polymorphism
562
+ (via template parameters) into a run-time choice of parameters.)
563
+
564
+ 2. Configuration instance: In the body of the registration function,
565
+ make a "using" declaration Operation_${operation_name} for the
566
+ operation type (which uses operation_name as its template argument).
567
+ Then, tell the manifest about the operation via a "manifest.append" call.
568
+ The argument of the call is a new instance of
569
+ "SomethingOperation<Operation_${operation_name}>"
570
+ (replace Something with a specific name).
571
+
572
+ 3. Configuration epilogue: Close the definition of the registration function.
573
+
574
+ 4. Epilogue template: Close the namespace(s).
575
+ """
576
+
577
+ _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__')
578
+ _LOGGER.debug('*** configuration_path (file to write): ' +
579
+ str(self.configuration_path))
580
+ _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
581
+
582
+ self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
583
+ 'configuration_name': self.configuration_name
584
+ }))
585
+
586
+ for operation in self.operations:
587
+ stub_begin = ''
588
+ stub_end = ''
589
+ # It can be useful to stub (comment) out instantiations for testing.
590
+ # In this case, one need only set is_stub to True.
591
+ is_stub = False
592
+ if is_stub:
593
+ stub_begin = "// STUB for now\n#if 0"
594
+ stub_end = "#endif // 0"
595
+
596
+ if operation.group_mode == GroupMode.Depthwise:
597
+ kernel_name = 'DirectConvolution'
598
+ operation_wrapper = 'DirectConv2dOperation'
599
+ else:
600
+ kernel_name = 'ImplicitGemmConvolution'
601
+ operation_wrapper = 'Conv2dOperation'
602
+ if self.operation_is_3x(operation):
603
+ kernel_name = 'ConvUniversalAdapter'
604
+ operation_wrapper = 'ConvOperation3x'
605
+
606
+ self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
607
+ 'configuration_name': self.configuration_name,
608
+ 'operation_name': operation.procedural_name(),
609
+ 'kernel_name': kernel_name,
610
+ 'operation_wrapper': operation_wrapper,
611
+ 'stub_begin': stub_begin,
612
+ 'stub_end': stub_end
613
+ }))
614
+
615
+ self.configuration_file.write(self.configuration_epilogue)
616
+ self.configuration_file.write(self.epilogue_template)
617
+ self.configuration_file.close()
618
+
619
+
620
+ ###################################################################################################
621
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting Conv3d kernels
35
+ """
36
+
37
+ import enum
38
+ import logging
39
+ import os.path
40
+ import shutil
41
+ from string import Template
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
49
+ except ImportError:
50
+ from library import *
51
+ from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
52
+
53
+ _LOGGER = logging.getLogger(__name__)
54
+
55
+ ###################################################################################################
56
+
57
+ #
58
+ class Conv3dOperation:
59
+ #
60
+ def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
61
+ stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
62
+
63
+ self.operation_kind = OperationKind.Conv3d
64
+ self.arch = arch
65
+ self.tile_description = tile_description
66
+ self.conv_kind = conv_kind
67
+ self.A = A
68
+ self.B = B
69
+ self.C = C
70
+ self.element_epilogue = element_epilogue
71
+ self.epilogue_functor = epilogue_functor
72
+ self.iterator_algorithm = iterator_algorithm
73
+ self.stride_support = stride_support
74
+ self.swizzling_functor = swizzling_functor
75
+
76
+ #
77
+ def is_mixed_input(self):
78
+ return self.A.element != self.B.element
79
+
80
+ #
81
+ def core_name(self):
82
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
83
+
84
+ intermediate_type = ''
85
+
86
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
87
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
88
+ if self.tile_description.math_instruction.element_a != self.A.element and \
89
+ self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
90
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
91
+ else:
92
+ inst_shape = ''
93
+
94
+ return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \
95
+ inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
96
+
97
+ #
98
+ def extended_name(self):
99
+ ''' Append data types if they differ from compute type. '''
100
+ if self.C.element != self.tile_description.math_instruction.element_accumulator and \
101
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
102
+ extended_name = "${element_c}_${core_name}_${element_a}"
103
+ elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
104
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
105
+ extended_name = "${core_name}_${element_a}"
106
+ else:
107
+ extended_name = "${core_name}"
108
+
109
+ extended_name = SubstituteTemplate(extended_name, {
110
+ 'element_a': DataTypeNames[self.A.element],
111
+ 'element_c': DataTypeNames[self.C.element],
112
+ 'core_name': self.core_name()
113
+ })
114
+
115
+ return extended_name
116
+
117
+ #
118
+ def configuration_name(self):
119
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
120
+
121
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
122
+
123
+ threadblock = "%dx%d_%dx%d" % (
124
+ self.tile_description.threadblock_shape[0],
125
+ self.tile_description.threadblock_shape[1],
126
+ self.tile_description.threadblock_shape[2],
127
+ self.tile_description.stages
128
+ )
129
+
130
+ if self.stride_support == StrideSupport.Unity:
131
+ configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride"
132
+ else:
133
+ configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}"
134
+
135
+ return SubstituteTemplate(
136
+ configuration_name,
137
+ {
138
+ 'opcode_class': opcode_class_name,
139
+ 'extended_name': self.extended_name(),
140
+ 'threadblock': threadblock,
141
+ }
142
+ )
143
+
144
+ #
145
+ def procedural_name(self):
146
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
147
+ return self.configuration_name()
148
+
149
+ ###################################################################################################
150
+ #
151
+ # Emits single instances of a CUTLASS device-wide operator
152
+ #
153
+ ###################################################################################################
154
+
155
+ class EmitConv3dInstance:
156
+ def __init__(self):
157
+ # Emitter for CUTLASS 3 convolution operations
158
+ self.conv3x_emitter = EmitConv3xInstance()
159
+ self.template = """
160
+ // Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
161
+ using ${operation_name}_base =
162
+ typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}<
163
+ ${element_a},
164
+ cutlass::layout::TensorNDHWC,
165
+ ${element_b},
166
+ cutlass::layout::TensorNDHWC,
167
+ ${element_c},
168
+ cutlass::layout::TensorNDHWC,
169
+ ${element_accumulator},
170
+ ${opcode_class},
171
+ ${arch},
172
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
173
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
174
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
175
+ ${epilogue_functor}<
176
+ ${element_c},
177
+ ${epilogue_vector_length},
178
+ ${element_accumulator},
179
+ ${element_epilogue}
180
+ >,
181
+ ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
182
+ ${stages},
183
+ cutlass::arch::OpMultiplyAdd,
184
+ ${iterator_algorithm},
185
+ ${stride_support}
186
+ >::Kernel;
187
+ """
188
+
189
+ def emit(self, operation):
190
+ _LOGGER.debug("*** EmitConv3dInstance::emit")
191
+ _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
192
+
193
+ if hasattr(operation, 'is_3x') and operation.is_3x:
194
+ _LOGGER.debug("*** CUTLASS 3 operation")
195
+ return self.conv3x_emitter.emit(operation)
196
+
197
+ _LOGGER.debug("*** CUTLASS 2 operation")
198
+
199
+ warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
200
+
201
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
202
+
203
+ values = {
204
+ 'operation_name': operation.procedural_name(),
205
+ 'conv_kind': ConvKindTag[operation.conv_kind],
206
+ 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
207
+ 'element_a': DataTypeTag[operation.A.element],
208
+ 'layout_a': LayoutTag[operation.A.layout],
209
+ 'element_b': DataTypeTag[operation.B.element],
210
+ 'layout_b': LayoutTag[operation.B.layout],
211
+ 'element_c': DataTypeTag[operation.C.element],
212
+ 'layout_c': LayoutTag[operation.C.layout],
213
+ 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
214
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
215
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
216
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
217
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
218
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
219
+ 'warp_shape_m': str(warp_shape[0]),
220
+ 'warp_shape_n': str(warp_shape[1]),
221
+ 'warp_shape_k': str(warp_shape[2]),
222
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
223
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
224
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
225
+ 'epilogue_vector_length': str(epilogue_vector_length),
226
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
227
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
228
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
229
+ 'stages': str(operation.tile_description.stages),
230
+ 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
231
+ 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
232
+ 'stride_support': StrideSupportTag[operation.stride_support]
233
+ }
234
+
235
+ return SubstituteTemplate(self.template, values)
236
+
237
+ ###################################################################################################
238
+ #
239
+ # Generator functions for all layouts
240
+ #
241
+ ###################################################################################################
242
+
243
+ #
244
+ def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
245
+
246
+ for tile in tile_descriptions:
247
+ for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
248
+
249
+ if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
250
+
251
+ #
252
+ output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
253
+ if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
254
+ else [tile.math_instruction.element_accumulator,]
255
+
256
+ for output_type in output_types:
257
+ A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
258
+ B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
259
+ C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type])))
260
+
261
+ manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
262
+
263
+ class EmitConv3dIncludes:
264
+ '''Emit includes that are specific to the operation.'''
265
+
266
+ def __init__(self):
267
+ self.includes = ['conv3d_operation.h']
268
+ self.emitter_3x = EmitConv3xIncludes()
269
+
270
+ def operation_is_3x(self, operation) -> bool:
271
+ """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
272
+ return hasattr(operation, 'is_3x') and operation.is_3x
273
+
274
+ def emit(self, operation) -> str:
275
+ if self.operation_is_3x(operation):
276
+ return self.emitter_3x.emit(operation)
277
+
278
+ return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
279
+ "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
280
+
281
+ ###################################################################################################
282
+ #
283
+ # Emitters functions for all targets
284
+ #
285
+ ###################################################################################################
286
+
287
+ class EmitConv3dConfigurationLibrary:
288
+ def __init__(self, operation_path, configuration_name):
289
+ self.configuration_name = configuration_name
290
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
291
+
292
+ self.instance_emitter = EmitConv3dInstance()
293
+ self.includes_emitter = EmitConv3dIncludes()
294
+
295
+ self.header_template = """
296
+ /*
297
+ Generated by conv3d_operation.py - Do not edit.
298
+ */
299
+
300
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
301
+
302
+ #include "cutlass/cutlass.h"
303
+ #include "cutlass/library/library.h"
304
+ #include "cutlass/library/manifest.h"
305
+
306
+ #include "library_internal.h"
307
+ """
308
+
309
+ self.instance_template = """
310
+ ${stub_begin}
311
+ ${operation_instance}
312
+ // Derived class
313
+ struct ${operation_name} :
314
+ public ${operation_name}_base { };
315
+ ${stub_end}
316
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
317
+
318
+ """
319
+
320
+ self.configuration_header = """
321
+
322
+ namespace cutlass {
323
+ namespace library {
324
+
325
+ // Initialize all instances
326
+ void initialize_${configuration_name}(Manifest &manifest) {
327
+ """
328
+
329
+ self.configuration_instance = """${stub_begin}
330
+ using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
331
+ ${operation_name}>;
332
+
333
+ manifest.append(new cutlass::library::${operation_wrapper}<
334
+ Operation_${operation_name}
335
+ >(
336
+ "${operation_name}"
337
+ ));
338
+ ${stub_end}
339
+ """
340
+
341
+ self.configuration_epilogue = "}\n"
342
+
343
+ self.epilogue_template = """
344
+
345
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
346
+
347
+ } // namespace library
348
+ } // namespace cutlass
349
+
350
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
351
+
352
+ """
353
+
354
+ def operation_is_3x(self, operation):
355
+ """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
356
+ return hasattr(operation, 'is_3x') and operation.is_3x
357
+
358
+ def __enter__(self):
359
+ """
360
+ Open the configuration_file, and write the "header" C++ code to it.
361
+
362
+ The "header" consists of a comment (that this is generated code,
363
+ so it should not be edited), and includes that are common
364
+ to both the CUTLASS 2 and the CUTLASS 3 cases.
365
+ """
366
+ _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__')
367
+ _LOGGER.debug('*** configuration_path (file to write): ' +
368
+ str(self.configuration_path))
369
+ _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
370
+ self.configuration_file = open(self.configuration_path, "w")
371
+
372
+ self.configuration_file.write(SubstituteTemplate(self.header_template, {
373
+ 'configuration_name': self.configuration_name
374
+ }))
375
+ self.operations = []
376
+ return self
377
+
378
+ def emit(self, operation):
379
+ """
380
+ Write three pieces of C++ code to the configuration_file
381
+ (that was opened by the __enter__ method above):
382
+
383
+ 1. the header includes that are specific to the operation
384
+ (CUTLASS 2 vs. CUTLASS 3);
385
+
386
+ 2. the "operation instance" (a "using" declaration ending in "_base"); and
387
+
388
+ 3. the "operation name" (declaration and definition of a derived class
389
+ of the above operation instance).
390
+
391
+ The "using" declaration turns a C++ class name, possibly namespace-qualified,
392
+ possibly also with angle brackets, into a C-style, easily demangled identifier.
393
+ """
394
+ _LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit')
395
+ _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
396
+ self.operations.append(operation)
397
+
398
+ self.configuration_file.write(self.includes_emitter.emit(operation))
399
+
400
+ stub_begin = ''
401
+ stub_end = ''
402
+ # It can be useful to stub (comment) out instantiations for testing.
403
+ # In this case, one need only set is_stub to True.
404
+ is_stub = False
405
+ if is_stub:
406
+ stub_begin = "// STUB for now\n#if 0"
407
+ stub_end = '#endif // 0'
408
+
409
+ self.configuration_file.write(Template(self.instance_template).substitute({
410
+ 'configuration_name': self.configuration_name,
411
+ 'operation_name': operation.procedural_name(),
412
+ 'operation_instance': self.instance_emitter.emit(operation),
413
+ 'stub_begin': stub_begin,
414
+ 'stub_end': stub_end
415
+ }))
416
+
417
+ def __exit__(self, exception_type, exception_value, traceback):
418
+ """
419
+ Write the rest of the C++ code to the configuration_file, and close the file.
420
+
421
+ The "rest of the C++ code" has the following components.
422
+
423
+ 1. Configuration header: Open the namespace(s), and open the definition
424
+ of the "initialize_${configuration_name}" registration function
425
+ that registers the operation with the Manifest.
426
+ ("Registration" helps turn C++ compile-time polymorphism
427
+ (via template parameters) into a run-time choice of parameters.)
428
+
429
+ 2. Configuration instance: In the body of the registration function,
430
+ make a "using" declaration Operation_${operation_name} for the
431
+ operation type (which uses operation_name as its template argument).
432
+ Then, tell the manifest about the operation via a "manifest.append" call.
433
+ The argument of the call is a new instance of
434
+ "SomethingOperation<Operation_${operation_name}>"
435
+ (replace Something with a specific name).
436
+
437
+ 3. Configuration epilogue: Close the definition of the registration function.
438
+
439
+ 4. Epilogue template: Close the namespace(s).
440
+ """
441
+
442
+ _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__')
443
+ _LOGGER.debug('*** configuration_path (file to write): ' +
444
+ str(self.configuration_path))
445
+ _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
446
+
447
+ self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
448
+ 'configuration_name': self.configuration_name
449
+ }))
450
+
451
+ for operation in self.operations:
452
+ stub_begin = ''
453
+ stub_end = ''
454
+ # It can be useful to stub (comment) out instantiations for testing.
455
+ # In this case, one need only set is_stub to True.
456
+ is_stub = False
457
+ if is_stub:
458
+ stub_begin = "// STUB for now\n#if 0"
459
+ stub_end = "#endif // 0"
460
+
461
+ kernel_name = 'ImplicitGemmConvolution'
462
+ operation_wrapper = 'Conv3dOperation'
463
+ if self.operation_is_3x(operation):
464
+ kernel_name = 'ConvUniversalAdapter'
465
+ operation_wrapper = 'ConvOperation3x'
466
+
467
+ self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
468
+ 'configuration_name': self.configuration_name,
469
+ 'operation_name': operation.procedural_name(),
470
+ 'kernel_name': kernel_name,
471
+ 'operation_wrapper': operation_wrapper,
472
+ 'stub_begin': stub_begin,
473
+ 'stub_end': stub_end
474
+ }))
475
+
476
+ self.configuration_file.write(self.configuration_epilogue)
477
+ self.configuration_file.write(self.epilogue_template)
478
+ self.configuration_file.close()
479
+
480
+
481
+ ###################################################################################################
482
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting CUTLASS >= 3 convolution kernels
35
+ """
36
+
37
+ import enum
38
+ import os.path
39
+ import shutil
40
+ import logging
41
+ from string import Template
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ except ImportError:
49
+ from library import *
50
+
51
+ _LOGGER = logging.getLogger(__name__)
52
+
53
+ ###################################################################################################
54
+ #
55
+ # Emits single instances of a CUTLASS device-wide operator
56
+ #
57
+ ###################################################################################################
58
+
59
+ class EmitConv3xInstance:
60
+ def __init__(self):
61
+ _LOGGER.debug("*** EmitConv3xInstance::__init__")
62
+
63
+ # Define epilogue type first, so that the mainloop type
64
+ # can use it with StageCountAutoCarveout.
65
+ self.template = """
66
+
67
+ // CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}"
68
+ using ${operation_name}_epilogue =
69
+ typename cutlass::epilogue::collective::CollectiveBuilder<
70
+ ${arch},
71
+ ${opcode_class_epi},
72
+ ${mma_tile_shape}, // mma tile shape
73
+ ${cluster_shape}, // cluster shape
74
+ ${epi_tile_mn},
75
+ ${element_accumulator},
76
+ ${element_compute},
77
+ ${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>,
78
+ ${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>,
79
+ ${epilogue_schedule}
80
+ // , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute>
81
+ >::CollectiveOp;
82
+
83
+ using ${operation_name}_mainloop =
84
+ typename cutlass::conv::collective::CollectiveBuilder<
85
+ ${arch},
86
+ ${opcode_class_main},
87
+ ${conv_kind}, // kFprop, kDgrad, or kWgrad
88
+ ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>,
89
+ ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>,
90
+ ${element_accumulator},
91
+ ${mma_tile_shape}, // mma tile shape
92
+ ${cluster_shape}, // cluster shape
93
+ ${stages},
94
+ ${kernel_schedule}
95
+ >::CollectiveOp;
96
+
97
+ using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>;
98
+
99
+ // Unit tests call this "ConvKernel".
100
+ // Conv operator ${operation_name}
101
+ using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
102
+ ${operation_name}_problem_shape,
103
+ ${operation_name}_mainloop,
104
+ ${operation_name}_epilogue,
105
+ ${tile_scheduler}
106
+ >;
107
+ """
108
+
109
+ def arch_number_to_type(self, arch: int) -> str:
110
+ return f"cutlass::arch::Sm{arch}"
111
+
112
+ def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
113
+ mma_m = cta_m
114
+ mma_n = cta_n
115
+ mma_k = cta_k
116
+
117
+ if operation.arch >= 100:
118
+ # MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where
119
+ # mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version.
120
+ # If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated,
121
+ # otherwise 1sm kernel is allocated.
122
+ cta_m_per_mma_instruction = 1
123
+ if "2sm" in operation.procedural_name() :
124
+ cta_m_per_mma_instruction = 2
125
+ elif "1sm" in operation.procedural_name() :
126
+ cta_m_per_mma_instruction = 1
127
+ elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 :
128
+ cta_m_per_mma_instruction = 2
129
+ mma_m = cta_m * cta_m_per_mma_instruction
130
+
131
+ # For all three kinds of convolutions, the tile shape's K mode
132
+ # differs from GEMM in that needs to be wrapped in a Shape.
133
+ # For Wgrad convolutions specifically,
134
+ # the N tile shape also needs to be wrapped in a Shape.
135
+ m_template = 'cute::_${mma_m}'
136
+ if operation.conv_kind == ConvKind.Wgrad:
137
+ n_template = 'cute::Shape<cute::_${mma_n}>'
138
+ else:
139
+ n_template = 'cute::_${mma_n}'
140
+ k_template = 'cute::Shape<cute::_${mma_k}>'
141
+
142
+ mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
143
+ values = {
144
+ 'mma_m': mma_m,
145
+ 'mma_n': mma_n,
146
+ 'mma_k': mma_k
147
+ }
148
+ return Template(mma_tile_shape_template).substitute(values)
149
+
150
+ def cluster_shape(self, operation) -> str:
151
+ m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)'
152
+ n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)'
153
+ k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)'
154
+ cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
155
+ values = {
156
+ 'cluster_shape_m': operation.tile_description.cluster_shape[0],
157
+ 'cluster_shape_n': operation.tile_description.cluster_shape[1],
158
+ 'cluster_shape_k': operation.tile_description.cluster_shape[2],
159
+ }
160
+ return Template(cluster_shape_template).substitute(values)
161
+
162
+ def stage_count(self, operation) -> str:
163
+ # stages == 0 tells builder to pick the number of stages automatically
164
+ namespace_prefix = 'cutlass::conv::collective::'
165
+ if operation.tile_description.stages > 0:
166
+ return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>"
167
+ else:
168
+ return f"{namespace_prefix}StageCountAutoCarveout<sizeof(typename {operation.procedural_name()}_epilogue::SharedStorage)>"
169
+
170
+ def emit(self, operation) -> str:
171
+ _LOGGER.debug("*** EmitConv3xInstance::emit")
172
+ _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
173
+
174
+ # Identify the operation as CUTLASS 3 by its is_3x field
175
+ if (not hasattr(operation, 'is_3x')) or (not operation.is_3x):
176
+ raise RuntimeError("operation must be a CUTLASS 3 operation")
177
+
178
+ epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
179
+ opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
180
+ opcode_class_epi = opcode_class_main
181
+
182
+ tile_shape = operation.tile_description.tile_shape
183
+ cluster_m = operation.tile_description.cluster_shape[0]
184
+ cluster_n = operation.tile_description.cluster_shape[1]
185
+
186
+ cta_m, cta_n, cta_k = tile_shape
187
+ # account for static/dynamic cluster shapes
188
+ if operation.arch >= 100:
189
+ cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m
190
+ cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n
191
+
192
+ warp_count = operation.tile_description.warp_count
193
+ epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule]
194
+
195
+ # KernelScheduleTag and TileSchedulerTag both hard-code the
196
+ # namespace qualification of KernelScheduleAuto as
197
+ # "cutlass::gemm::collective::" (unless the tag is 'void').
198
+ #
199
+ # For TileSchedulerTag, this namespace is fine, since CUTLASS 3
200
+ # convolutions use the same tile schedulers (from the same
201
+ # cutlass::gemm::collective namespace) as GEMMs.
202
+ kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::')
203
+ tile_scheduler = TileSchedulerTag[operation.tile_scheduler]
204
+ opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
205
+
206
+ values = {
207
+ 'operation_name': operation.procedural_name(),
208
+ 'conv_kind': ConvKindTag[operation.conv_kind],
209
+ 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
210
+ 'element_a': DataTypeTag[operation.A.element],
211
+ 'layout_a': LayoutTag[operation.A.layout],
212
+ 'align_a': int(operation.A.alignment),
213
+ 'element_b': DataTypeTag[operation.B.element],
214
+ 'layout_b': LayoutTag[operation.B.layout],
215
+ 'align_b': int(operation.B.alignment),
216
+ 'element_c': DataTypeTag[operation.C.element],
217
+ 'layout_c': LayoutTag[operation.C.layout],
218
+ 'align_c': int(operation.C.alignment),
219
+ 'element_d': DataTypeTag[operation.D.element],
220
+ 'layout_d': LayoutTag[operation.D.layout],
221
+ 'align_d': int(operation.D.alignment),
222
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
223
+ 'opcode_class': opcode_class,
224
+ 'arch': self.arch_number_to_type(operation.arch),
225
+ 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
226
+ 'cluster_shape': self.cluster_shape(operation),
227
+ 'opcode_class_epi': opcode_class_epi,
228
+ 'opcode_class_main': opcode_class_main,
229
+ 'epi_tile_mn': epi_tile_mn,
230
+ 'stages': self.stage_count(operation),
231
+ 'kernel_schedule': kernel_schedule,
232
+ 'epilogue_schedule': epilogue_schedule,
233
+ 'tile_scheduler': tile_scheduler,
234
+ 'element_compute': DataTypeTag[operation.element_compute]
235
+ }
236
+ return Template(self.template).substitute(values)
237
+
238
+ class EmitConv3xIncludes:
239
+ def __init__(self):
240
+ _LOGGER.debug("*** EmitConv3xIncludes::__init__")
241
+ self.includes = ['conv_operation_3x.hpp',
242
+ 'cutlass/conv/device/conv_universal_adapter.hpp',
243
+ 'cutlass/conv/kernel/conv_universal.hpp',
244
+ 'cutlass/conv/collective/collective_builder.hpp',
245
+ 'cutlass/epilogue/collective/collective_builder.hpp']
246
+
247
+ def emit(self, operation) -> str:
248
+ _LOGGER.debug("*** EmitConv3xIncludes::emit")
249
+ return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
250
+ "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/emit_kernel_listing.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ #
34
+ #
35
+ # \brief Generates the CUTLASS kernel listing with kernel filtering
36
+ #
37
+
38
+ #
39
+
40
+ ###############################################################################
41
+ # Example usage:
42
+ # generator.py --operations all --generator-target kernel_listing \
43
+ # --architectures "70;75;80" --kernels "*" --disable-cutlass-package-imports
44
+ ###############################################################################
45
+
46
+ import collections
47
+ import csv
48
+ import json
49
+ import math
50
+ import os
51
+
52
+ try:
53
+ import builtins
54
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
55
+ raise ImportError("Disabling attempt to import cutlass_library")
56
+ from cutlass_library.library import *
57
+ except ImportError:
58
+ from library import *
59
+
60
+ audit_csv_fields = [
61
+ "KernelType", "KernelName", "Type_A", "Type_B", "Type_C", "Type_Acc", "Type_EpilogueScale", "Type_D", "Type_SFA", "Type_SFD",
62
+ "Layout_A", "Layout_B", "Layout_C", "Layout_D",
63
+ "Alignment_A", "Alignment_B", "Alignment_C", "Alignment_D",
64
+ "1SM/2SM",
65
+ "StreamK Enabled", "Support Runtime_Cluster_Shape", "Support Runtime_Input_Types",
66
+ "Test Counts"
67
+ ]
68
+
69
+ audit_csv_runtime_fields = [
70
+ "KerneIndex", "KernelName",
71
+ "Inst_M", "Inst_N", "Inst_K", "Tile_M", "Tile_N", "Tile_K",
72
+ "Cluster_M", "Cluster_N", "Cluster_K", "Preferred_Cluster_M", "Preferred_Cluster_N", "Preferred_Cluster_K", "Fallback_Cluster_M", "Fallback_Cluster_N", "Fallback_Cluster_K",
73
+ "M", "N", "K", "L", "Alpha_val", "Beta_val",
74
+ "Runtime_Input_Types Enabled", "Runtime_Cluster_Shape Enabled"
75
+ ]
76
+
77
+ def hash_cutlass_string(input_string):
78
+ mma_cluster_shape_pattern = r"_\d+x\d+x\d+" # Matches MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
79
+
80
+ # Remove MMA and Cluster shapes (e.g., '_128x128x256', '_0x0x1')
81
+ output = re.sub(mma_cluster_shape_pattern, "", input_string)
82
+
83
+ return output
84
+
85
+ def transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b):
86
+ # Define a dictionary mapping the detected types to runtime values
87
+ datatype_map = {
88
+ 'f4_f4': runtime_datatype_a + '_' + runtime_datatype_b,
89
+ 'f4_f6': runtime_datatype_a + '_' + runtime_datatype_b,
90
+ 'f4_f8': runtime_datatype_a + '_' + runtime_datatype_b,
91
+ 'f6_f4': runtime_datatype_a + '_' + runtime_datatype_b,
92
+ 'f6_f6': runtime_datatype_a + '_' + runtime_datatype_b,
93
+ 'f6_f8': runtime_datatype_a + '_' + runtime_datatype_b,
94
+ 'f8_f4': runtime_datatype_a + '_' + runtime_datatype_b,
95
+ 'f8_f6': runtime_datatype_a + '_' + runtime_datatype_b,
96
+ 'f8_f8': runtime_datatype_a + '_' + runtime_datatype_b,
97
+ 'ue8m0xf4_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
98
+ 'ue4m3xf4_ue4m3xf4': 'ue4m3x' + runtime_datatype_a + '_ue4m3x' + runtime_datatype_b,
99
+ 'ue8m0xf4_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
100
+ 'ue8m0xf4_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
101
+ 'ue8m0xf6_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
102
+ 'ue8m0xf6_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
103
+ 'ue8m0xf8_ue8m0xf4': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
104
+ 'ue8m0xf8_ue8m0xf6': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
105
+ 'ue8m0xf8_ue8m0xf8': 'ue8m0x' + runtime_datatype_a + '_ue8m0x' + runtime_datatype_b,
106
+ }
107
+
108
+ # Regular expression to detect all the keys in datatype_map
109
+ pattern = re.compile(r'(' + '|'.join(map(re.escape, datatype_map.keys())) + r')')
110
+
111
+ # Replace detected patterns using the dictionary
112
+ updated_kernel_name = pattern.sub(lambda match: datatype_map[match.group(0)], hashed_kernel_name)
113
+
114
+ return updated_kernel_name
115
+
116
+ # This helper function reports foundational kernel features: datatypes, layouts, alignment and stream-k.
117
+ def get_kernel_features(operation, kernel_name,
118
+ dynamic_datatype, runtime_input_datatype):
119
+ numcta_inst = "2sm" if "2sm" in kernel_name else "1sm"
120
+ math_inst = operation.tile_description.math_instruction
121
+
122
+ if dynamic_datatype:
123
+ dtype_name_A = runtime_input_datatype[0]
124
+ dtype_name_B = runtime_input_datatype[1]
125
+ else:
126
+ dtype_name_A = DataTypeNames[operation.A.element]
127
+ dtype_name_B = DataTypeNames[operation.B.element]
128
+
129
+ layout_name_A = ShortLayoutTypeNames[operation.A.layout]
130
+ layout_name_B = ShortLayoutTypeNames[operation.B.layout]
131
+ layout_name_C = ShortLayoutTypeNames[operation.C.layout]
132
+ layout_name_D = ShortLayoutTypeNames[operation.D.layout]
133
+
134
+ scale_factor_D_type = operation.ScaleFactorD.element if hasattr(operation, "ScaleFactorD") else DataType.void
135
+ scale_factor_A_type = getattr(operation, "ScaleFactorA", DataType.void)
136
+ audit_vals = [
137
+ "BlockScaledGEMM" if math_inst.opcode_class == OpcodeClass.BlockScaledTensorOp else "GEMM",
138
+ kernel_name,
139
+ dtype_name_A,
140
+ dtype_name_B,
141
+ DataTypeNames[operation.C.element],
142
+ DataTypeNames[operation.tile_description.math_instruction.element_accumulator],
143
+ DataTypeNames[operation.element_epilogue],
144
+ DataTypeNames[operation.D.element],
145
+ DataTypeNames[scale_factor_D_type],
146
+ DataTypeNames[scale_factor_A_type],
147
+ layout_name_A,
148
+ layout_name_B,
149
+ layout_name_C,
150
+ layout_name_D,
151
+ str(operation.A.alignment),
152
+ str(operation.B.alignment),
153
+ str(operation.C.alignment),
154
+ str(operation.D.alignment),
155
+ numcta_inst,
156
+ "Y" if 'stream_k' in kernel_name else "N",
157
+ ]
158
+ return audit_vals
159
+
160
+ # This helper function reports other performance-related kernel parameters and those can be specified at runtime: cluster_shape, instruction shap, m/n/k and alpha/beta.
161
+ def get_kernel_params(operation, kernel_name, cluster_shape, fallback_cluster_shape, problem_shape, alpha, beta, dynamic_datatype, dynamic_cluster):
162
+ math_inst = operation.tile_description.math_instruction
163
+ audit_vals = [
164
+ str(math_inst.instruction_shape[0]),
165
+ str(math_inst.instruction_shape[1]),
166
+ str(math_inst.instruction_shape[2]),
167
+ str(operation.tile_description.threadblock_shape[0]),
168
+ str(operation.tile_description.threadblock_shape[1]),
169
+ str(operation.tile_description.threadblock_shape[2]),
170
+ str(operation.tile_description.cluster_shape[0]),
171
+ str(operation.tile_description.cluster_shape[1]),
172
+ str(operation.tile_description.cluster_shape[2]),
173
+ str(cluster_shape[0]),
174
+ str(cluster_shape[1]),
175
+ str(cluster_shape[2]),
176
+ str(fallback_cluster_shape[0]),
177
+ str(fallback_cluster_shape[1]),
178
+ str(fallback_cluster_shape[2]),
179
+ str(problem_shape[0]),
180
+ str(problem_shape[1]),
181
+ str(problem_shape[2]),
182
+ str(problem_shape[3]),
183
+ str(alpha),
184
+ str(beta),
185
+ "Y" if dynamic_datatype else "N",
186
+ "Y" if dynamic_cluster else "N",
187
+ ]
188
+ return audit_vals
189
+
190
+
191
+ def _getSubOperationType(kernel):
192
+
193
+ if kernel.operation_kind == OperationKind.Gemm:
194
+ return GemmKindNames[kernel.gemm_kind]
195
+ elif kernel.operation_kind == OperationKind.Conv2d:
196
+ return "conv_" + ConvKindNames[kernel.conv_kind]
197
+ elif kernel.operation_kind == OperationKind.Syrk:
198
+ return "syrk_" + SyrkKindNames[kernel.syrk_kind]
199
+ elif kernel.operation_kind == OperationKind.Trmm:
200
+ return "trmm_" + TrmmKindNames[kernel.trmm_kind]
201
+ elif kernel.operation_kind == OperationKind.Symm:
202
+ return "symm_" + SymmKindNames[kernel.symm_kind]
203
+ else:
204
+ raise Exception("Unsupported kernel type")
205
+
206
+ def _get_inst_shape(math_instruction):
207
+ return "".join(str(x) for x in math_instruction.instruction_shape)
208
+
209
+ def _is_simt_inst(math_instruction):
210
+ return _get_inst_shape(math_instruction) in ["111","114"]
211
+
212
+ def _getInstType(input_precision, accumulate_precision, math_instruction):
213
+
214
+ # inst_shape
215
+ inst_shape = _get_inst_shape(math_instruction)
216
+
217
+ # input precision
218
+ if input_precision == "fp32" and inst_shape != "111":
219
+ inp = "tf32"
220
+ else:
221
+ inp = input_precision
222
+
223
+ # Handle SIMT op types first
224
+ if _is_simt_inst(math_instruction):
225
+
226
+ simt_input_precision_to_inst = {
227
+ "fp32": "FFMA",
228
+ "fp64": "DFMA",
229
+ "fp16": "HFMA",
230
+ "int8": "IDP4A",
231
+ }
232
+ inst = simt_input_precision_to_inst[input_precision]
233
+
234
+ else: # Tensor op instructions
235
+
236
+ if accumulate_precision == "cf64":
237
+ fp64_acc_map = {
238
+ MathOperation.multiply_add_complex_gaussian : "gz",
239
+ MathOperation.multiply_add_complex : "z",
240
+ }
241
+ acc = fp64_acc_map[math_instruction.math_operation]
242
+ else:
243
+ tensor_op_acc_map = {
244
+ "fp32" : "s",
245
+ "cf32" : "s",
246
+ "fp16" : "h",
247
+ "int32": "i",
248
+ "fp64" : "d",
249
+ }
250
+ acc = tensor_op_acc_map[accumulate_precision]
251
+
252
+ inst = "{}{}{}".format(acc, inst_shape, inp)
253
+
254
+ return inst
255
+ # TODO: Computes FLOps/Bytes for GEMM - revisit for conv
256
+ def _computeFlopsPerByte(operation, m, n, k, batch_count=1, beta=0.0, num_groups=1):
257
+ assert not (batch_count > 1 and num_groups > 1)
258
+
259
+ # TODO: adjust for sparsity
260
+ gmem_bytes = (
261
+ (DataTypeSize[operation.A.element] * m // 8) * k +
262
+ (DataTypeSize[operation.B.element] * n // 8) * k +
263
+ (DataTypeSize[operation.C.element] * m // 8) * n
264
+ )
265
+
266
+ # TODO: complex-valued support
267
+ flops = 2 * (m * n * k)
268
+
269
+ if bool(beta):
270
+ gmem_bytes += (DataTypeSize[operation.C.element] * m // 8) * n
271
+ flops += 2 * m * n
272
+
273
+ multiplier = max(batch_count, num_groups)
274
+ gmem_bytes *= multiplier
275
+ flops *= multiplier
276
+
277
+ return flops / gmem_bytes
278
+
279
+ def emit_gemm_kernel_testlist(manifest, curr_build_dir, arch, mode
280
+ ):
281
+ # For functional testing, we prefer to run reference computing on device if any
282
+ reference_device_archs = ["100a", "103a"]
283
+ run_reference_on_device = True if arch in reference_device_archs and mode in ["functional_L0", "functional_L1"] else False
284
+ profiler_flags_for_verification = "device" if run_reference_on_device else "host"
285
+
286
+ # beta values for L0 and L1
287
+ # TODO: randomize beta values for wider coverage
288
+ beta_values = [0.5]
289
+
290
+ is_supported_arch = (arch in ["100a", "100f", "101a", "101f", "103a", "110a", "110f", "120a", "120f", "121a", "121f"])
291
+
292
+ is_runtime_datatype_enabled = mode == "functional_L0" and is_supported_arch
293
+
294
+ if (mode == "functional_L0") and is_supported_arch:
295
+ problem_waves = [0.5, 1.25, 2.5]
296
+
297
+ #
298
+ # Dense Gemm
299
+ #
300
+
301
+ sm100_mma_data_type_general = [
302
+ 'gemm_f16_f16_f16_f16_f16',
303
+ 'gemm_f16_f16_f16_void_f16',
304
+ #'gemm_f16_f16_f32_f16_f16',
305
+ 'tf32gemm_f32_f32_f32_f32_f32',
306
+ 'bf16gemm_f32_f32_f32_f32_f32',
307
+ ]
308
+
309
+ exclude_archs = arch not in ("103a")
310
+ if exclude_archs:
311
+ sm100_mma_data_type_general.append('gemm_s8_s8_s32_s8_s8')
312
+
313
+ sm100_mma_data_type_runtime_dtype = [
314
+ 'gemm.*f4_f4_f32_f32_f32',
315
+ 'gemm.*f6_f6_f32_f32_f32',
316
+ 'gemm.*f8_f8_f32_f32_f32',
317
+ ]
318
+
319
+ sm100_mma_cluster_size = [
320
+ '8x1x1',
321
+ '4x4x1', '2x1x1',
322
+ '0x0x1' # dynamic cluster
323
+ ]
324
+
325
+ # Restrict to two layouts to reduce L0 build and test time.
326
+ sm100_mma_layouts = [
327
+ 'tnt',
328
+ 'ntn'
329
+ ]
330
+
331
+ # regex list must be in kernel procedural name order
332
+ sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
333
+ sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_general, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
334
+
335
+ sm100_mma_filter_regex_1sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
336
+ sm100_mma_filter_regex_2sm_runtime = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_data_type_runtime_dtype, sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
337
+
338
+ #
339
+ # Block Scale Gemm
340
+ #
341
+
342
+ block_scaled_data_type = [
343
+ # runtime datatypes
344
+ 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
345
+ 'gemm.*ue4m3xf4_ue4m3xf4_f32_f16_e5m2',
346
+ 'gemm.*ue8m0xf4_ue8m0xf6_f32_f16_e5m2',
347
+ #'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
348
+ 'gemm.*ue8m0xf6_ue8m0xf6_f32_f16_ue8m0xe3m2',
349
+ ]
350
+
351
+ block_scaled_tile_k = ['x128_', 'x256_']
352
+
353
+ sm103_block_scaled_data_type = [
354
+ 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_e5m2',
355
+ 'gemm.*ue8m0xf4_ue8m0xf4_f32_f16_ue8m0xe2m1',
356
+ ]
357
+
358
+ sm103_block_scaled_tile_k = ['x768_']
359
+
360
+ block_scaled_cluster_size = [
361
+ '4x4x1', '2x1x1',
362
+ '0x0x1' # dynamic cluster
363
+ ]
364
+
365
+ block_scaled_layouts = ['tnt']
366
+ # regex list must be in kernel procedural name order
367
+ block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
368
+ block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
369
+
370
+ sm103_block_scaled_prefetch_policy = ['tmapf']
371
+ sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*"
372
+ sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, sm103_block_scaled_tile_k, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*(" + "|".join(sm103_block_scaled_prefetch_policy) + ").*"
373
+
374
+ if arch in ["100a", "100f"]:
375
+ kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
376
+ f"({sm100_mma_filter_regex_2sm})|" \
377
+ f"({sm100_mma_filter_regex_1sm_runtime})|" \
378
+ f"({sm100_mma_filter_regex_2sm_runtime})|" \
379
+ f"({block_scaled_filter_regex_1sm})|" \
380
+ f"({block_scaled_filter_regex_2sm})"
381
+ elif arch in ["101a", "101f", "110a", "110f"]:
382
+ kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
383
+ f"({sm100_mma_filter_regex_2sm})|" \
384
+ f"({sm100_mma_filter_regex_1sm_runtime})|" \
385
+ f"({sm100_mma_filter_regex_2sm_runtime})|" \
386
+ f"({block_scaled_filter_regex_1sm})|" \
387
+ f"({block_scaled_filter_regex_2sm})"
388
+ elif arch in ["103a"]:
389
+ kernel_filter = f"({sm100_mma_filter_regex_1sm})|" \
390
+ f"({sm100_mma_filter_regex_2sm})|" \
391
+ f"({sm100_mma_filter_regex_1sm_runtime})|" \
392
+ f"({sm100_mma_filter_regex_2sm_runtime})|" \
393
+ f"({block_scaled_filter_regex_1sm})|" \
394
+ f"({block_scaled_filter_regex_2sm})|" \
395
+ f"({sm103_block_scaled_filter_regex_1sm})|" \
396
+ f"({sm103_block_scaled_filter_regex_2sm})"
397
+ elif arch in ["120a", "120f", "121a", "121f"]:
398
+
399
+ # blockscaled sm120_mma kernels
400
+ blockscaled_sm120_mma_kernel_cta_tiles = [
401
+ [ '128x128' ]
402
+ ]
403
+
404
+ # Restrict to two layouts to reduce L0 build and test time.
405
+ blockscaled_sm120_mma_layouts = [ 'tn' ]
406
+ filter_regex_blockscaled_sm120_mma = "cutlass3x_sm120_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [blockscaled_sm120_mma_kernel_cta_tiles[0], blockscaled_sm120_mma_layouts]]) + ").*"
407
+
408
+ problem_waves = [0.5, 1.25, 2.5]
409
+
410
+ kernel_filter = f"({filter_regex_blockscaled_sm120_mma})"
411
+ else:
412
+ error_message = "unsupported arch, only support sm100a, sm100f, sm101a, sm101f, sm110a, sm110f, sm103a, sm120a, sm120f, sm121a, sm121f"
413
+ raise Exception(error_message)
414
+
415
+ elif mode == "functional_L1":
416
+ sm100_mma_cluster_size = [
417
+ '0x0x1' # dynamic cluster
418
+ ]
419
+ # Restrict to two layouts to reduce L1 build and test time.
420
+ sm100_mma_layouts = ['tnt', 'ntn']
421
+ sm100_mma_filter_regex_1sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*1sm.*"
422
+ sm100_mma_filter_regex_2sm = "cutlass3x_sm100_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm100_mma_cluster_size, sm100_mma_layouts]]) + ").*2sm.*"
423
+ block_scaled_data_type = [
424
+ 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
425
+ 'ue8m0xe2m1_ue8m0xe2m3_f32_f16_e5m2',
426
+ 'ue8m0xmx8s26_ue8m0xmx8s26_f32_f16_e5m2',
427
+ 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
428
+ 'ue8m0xe2m3_ue8m0xe2m3_f32_f16_ue8m0xe3m2',
429
+ ]
430
+
431
+ sm103_block_scaled_data_type = [
432
+ 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_e5m2',
433
+ 'ue8m0xe2m1_ue8m0xe2m1_f32_f16_ue8m0xe2m1',
434
+ ]
435
+
436
+ block_scaled_cluster_size = ['0x0x1']
437
+ block_scaled_layouts = ['tnt']
438
+
439
+ # regex list must be in kernel procedural name order
440
+ block_scaled_filter_regex_1sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
441
+ block_scaled_filter_regex_2sm = "cutlass3x_sm100_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
442
+
443
+ sm103_block_scaled_filter_regex_1sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*1sm.*"
444
+ sm103_block_scaled_filter_regex_2sm = "cutlass3x_sm103_bstensorop.*(" + ").*(".join([ "|".join(x) for x in [sm103_block_scaled_data_type, block_scaled_cluster_size, block_scaled_layouts]]) + ").*2sm.*"
445
+
446
+ filter_regex_sm100_mma = f"({sm100_mma_filter_regex_1sm})|" \
447
+ f"({sm100_mma_filter_regex_2sm})|" \
448
+ f"({block_scaled_filter_regex_1sm})|" \
449
+ f"({block_scaled_filter_regex_2sm})" \
450
+ f"({sm103_block_scaled_filter_regex_1sm})|" \
451
+ f"({sm103_block_scaled_filter_regex_2sm})"
452
+ # CTA tiles for sm120 MMA - only run one tile size to reduce build/test times
453
+ sm120_mma_kernel_cta_tiles = [
454
+ # h1688, s1688, i16832, i8816
455
+ [ '256x128' ],
456
+ # d884, c1688,
457
+ [ '128x128' ],
458
+ # c1688, z884
459
+ [ '128x64' ],
460
+ # gz884
461
+ [ '64x64' ]
462
+ ]
463
+
464
+ # sm120 MMA instruction shapes, planar complex type excluded as they are not required
465
+ sm120_mma_instruction_shapes = [
466
+ [ 'h1688gemm_(?!planar_complex)',
467
+ 's1688gemm_f16',
468
+ 's1688gemm_bf16',
469
+ 's1688gemm_tf32',
470
+ 'i16832gemm',
471
+ 'i8816gemm' ],
472
+ [ 'd884gemm', 'c1688tf32gemm' ] ,
473
+ [ 'c1688gemm',
474
+ 'z884gemm' ],
475
+ [ 'gz884gemm']
476
+ ]
477
+
478
+ # It's not pretty, but not sure why different instructions support different tile sizes.
479
+ filter_regex_sm120_mma_0 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[0], sm120_mma_kernel_cta_tiles[0]]]) + ").*"
480
+ filter_regex_sm120_mma_1 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[1], sm120_mma_kernel_cta_tiles[1]]]) + ").*"
481
+ filter_regex_sm120_mma_2 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[2], sm120_mma_kernel_cta_tiles[2]]]) + ").*"
482
+ filter_regex_sm120_mma_3 = "cutlass_tensorop.*(" + ").*(".join([ "|".join(x) for x in [sm120_mma_instruction_shapes[3], sm120_mma_kernel_cta_tiles[3]]]) + ").*"
483
+
484
+ filter_regex_sm120_mma = f"({filter_regex_sm120_mma_0})|({filter_regex_sm120_mma_1})|({filter_regex_sm120_mma_2})|({filter_regex_sm120_mma_3})"
485
+
486
+ problem_waves = [0.5, 1.25, 2.5]
487
+
488
+ if arch in ["120a", "120f", "121a", "121f"]:
489
+ kernel_filter = f"({filter_regex_sm120_mma})"
490
+ else:
491
+ kernel_filter = f"({filter_regex_sm100_mma})"
492
+ else:
493
+ raise ValueError()
494
+
495
+ outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
496
+
497
+ audit_file_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_SM{arch}_cutlass3x_gemm.csv")
498
+
499
+ audit_file_params_name = os.path.join(curr_build_dir, f"FK_{mode}_audit_params_SM{arch}_cutlass3x_gemm.csv")
500
+
501
+ kernel_filter_re = re.compile(kernel_filter)
502
+ testcase_counter = 0
503
+ kernels_emitted = 0
504
+ kernels_total = 0
505
+
506
+ perf_json_list = []
507
+ kernel_name_set = set()
508
+
509
+ testlist_csv_fields = ["testcase", "metadata"]
510
+ testlist_csv_rows = []
511
+ auditlist_csv_map = {}
512
+ auditlist_csv_params_map = {}
513
+
514
+ kernel_features = {}
515
+
516
+ for cc in manifest.operations[OperationKind.Gemm].keys():
517
+ for kernel_name, operation_l in manifest.operations[OperationKind.Gemm][cc].items():
518
+ assert(len(operation_l) == 1)
519
+ kernels_total += 1
520
+ if len(kernel_filter_re.findall(kernel_name)) == 0:
521
+ continue
522
+ # Only test f16 I/O void C kernels in void C kernel set
523
+ # Exception: Use void C kernels for more accurate perf testing
524
+ if '_void_' in kernel_name and 'perf_' not in mode:
525
+ if 'f16_f16_f16_void_f16' not in kernel_name :
526
+ continue
527
+
528
+ kernels_emitted += 1
529
+ kernel_name_set.add(kernel_name)
530
+ hashed_kernel_name = hash_cutlass_string(kernel_name)
531
+ operation = operation_l[0]
532
+
533
+ dynamic_cluster = (operation.tile_description.cluster_shape[0] == 0
534
+ or operation.tile_description.cluster_shape[1] == 0)
535
+
536
+ dynamic_datatype = "f8" in kernel_name or "f6" in kernel_name or "f4" in kernel_name
537
+
538
+ runtime_input_datatypes = [None]
539
+
540
+ if dynamic_datatype:
541
+ if "f4_f4" in kernel_name:
542
+ runtime_input_datatypes = [['e2m1','e2m1']]
543
+ elif "f4_f6" in kernel_name:
544
+ runtime_input_datatypes = [['e2m1','e3m2']]
545
+ elif "f4_f8" in kernel_name:
546
+ runtime_input_datatypes = [['e2m1','e4m3']]
547
+
548
+ elif "f6_f4" in kernel_name:
549
+ runtime_input_datatypes = [['e3m2','e2m1']]
550
+ elif "f6_f6" in kernel_name:
551
+ runtime_input_datatypes = [['e3m2','e3m2']]
552
+ elif "f6_f8" in kernel_name:
553
+ runtime_input_datatypes = [['e3m2','e4m3']]
554
+
555
+ elif "f8_f4" in kernel_name:
556
+ runtime_input_datatypes = [['e4m3','e2m1']]
557
+ elif "f8_f6" in kernel_name:
558
+ runtime_input_datatypes = [['e4m3','e3m2']]
559
+ elif "f8_f8" in kernel_name:
560
+ runtime_input_datatypes = [
561
+ # mask out those not covered in statically encoded test cases
562
+ # ['e5m2','e4m3'],
563
+ # ['e4m3','e5m2'],
564
+ ['e4m3','e4m3']
565
+ ]
566
+
567
+ # block scaled kernels
568
+ elif "ue8m0xf4_ue8m0xf4" in kernel_name:
569
+ runtime_input_datatypes = [['e2m1','e2m1']]
570
+ elif "ue4m3xf4_ue4m3xf4" in kernel_name:
571
+ runtime_input_datatypes = [['e2m1','e2m1']]
572
+ elif "ue8m0xf4_ue8m0xf6" in kernel_name:
573
+ runtime_input_datatypes = [['e2m1','e2m3']]
574
+ elif "ue8m0xf4_ue8m0xf8" in kernel_name:
575
+ runtime_input_datatypes = [['e2m1','e4m3']]
576
+
577
+ elif "ue8m0xf6_ue8m0xf4" in kernel_name:
578
+ runtime_input_datatypes = [['e2m3','e2m1']]
579
+ elif "ue8m0xf6_ue8m0xf6" in kernel_name:
580
+ runtime_input_datatypes = [['e2m3','e2m3']]
581
+ elif "ue8m0xf8_ue8m0xf4" in kernel_name:
582
+ runtime_input_datatypes = [['e4m3','e2m1']]
583
+
584
+ elif "ue8m0xf8_ue8m0xf4" in kernel_name:
585
+ runtime_input_datatypes = [['e4m3','e2m1']]
586
+ elif "ue8m0xf8_ue8m0xf6" in kernel_name:
587
+ runtime_input_datatypes = [['e4m3','e2m3']]
588
+ elif "ue8m0xf8_ue8m0xf8" in kernel_name:
589
+ runtime_input_datatypes = [['e4m3','e4m3']]
590
+
591
+ if "bstensorop" in kernel_name or is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
592
+ profiler_flags_for_verification = "host"
593
+
594
+ # reduce L1 test runtime if reference kernel is not running on device.
595
+ if mode == "functional_L1" and profiler_flags_for_verification == "host" :
596
+ problem_waves = [0.5, 2.5]
597
+
598
+
599
+ if dynamic_cluster:
600
+ if mode == "functional_L0":
601
+ runtime_cluster_shapes = [[1,1,1], [2,2,1]]
602
+ else:
603
+ runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1], [2,4,1], [4,2,1], [4,4,1]]
604
+ # reduce L1 test runtime if reference kernel is not running on device.
605
+ if profiler_flags_for_verification == "host":
606
+ runtime_cluster_shapes = [[1,1,1], [1,2,1], [2,1,1], [2,2,1], [1,4,1], [4,1,1]]
607
+ cta_tile_shape_m, cta_tile_shape_n, cta_tile_shape_k = operation.tile_description.threadblock_shape
608
+ else:
609
+ runtime_cluster_shapes = [operation.tile_description.cluster_shape]
610
+ cta_tile_shape_m = int(operation.tile_description.threadblock_shape[0] / operation.tile_description.cluster_shape[0])
611
+ cta_tile_shape_n = int(operation.tile_description.threadblock_shape[1] / operation.tile_description.cluster_shape[1])
612
+ cta_tile_shape_k = int(operation.tile_description.threadblock_shape[2] / operation.tile_description.cluster_shape[2])
613
+
614
+ alignment_a = operation.A.alignment
615
+ alignment_b = operation.B.alignment
616
+ alignment_c = operation.C.alignment
617
+ alignment_ab_max = max(alignment_a, alignment_b)
618
+
619
+ layout3x = operation.layout_name_3x()
620
+ data_types = operation.datatype_name_3x()
621
+
622
+ ctas_per_mma_instruction = 1
623
+ if '_2sm' in kernel_name:
624
+ ctas_per_mma_instruction = 2
625
+ valid_cluster_shapes = []
626
+
627
+ # Remove any cluster shapes that have cluster_m that is not divisible by 2
628
+ for cs in runtime_cluster_shapes:
629
+ if cs[0] % 2 == 0:
630
+ valid_cluster_shapes.append(cs)
631
+ runtime_cluster_shapes = valid_cluster_shapes
632
+
633
+ kernel_problem_waves = problem_waves
634
+ if mode == "functional_L0" or mode == "functional_L1":
635
+ # for functional testing, we want to perturb just a little from even shapes
636
+ # large K = 8 is chosen such that some kernels will warp around their smem buffers, and some will not
637
+ # -16 ensures that we are TMA aligned even for FP8/Int8
638
+ min_k = alignment_ab_max if cta_tile_shape_k == alignment_ab_max else cta_tile_shape_k - alignment_ab_max
639
+ max_k = (cta_tile_shape_k*8) - alignment_ab_max
640
+ problem_shapes_k = [min_k, max_k]
641
+ sm_count = 16
642
+ swizzle_sizes = [0]
643
+ # Larger k and less than half wave trigger streamk +separate reduction case to be generated
644
+ if 'stream_k' in kernel_name:
645
+ problem_shapes_k = [max_k, cta_tile_shape_k*32]
646
+ kernel_problem_waves = [0.125, 1.25, 2.5]
647
+ else:
648
+ raise ValueError
649
+
650
+ if "void" in kernel_name:
651
+ beta_values = [0]
652
+
653
+ alignment_shift_m = max(alignment_c, alignment_a)
654
+ alignment_shift_n = max(alignment_c, alignment_b)
655
+
656
+ is_first_line = True
657
+ for index_waves, waves in enumerate(kernel_problem_waves):
658
+ for index_k, k in enumerate(problem_shapes_k):
659
+ for beta in beta_values:
660
+ for cluster_shape in runtime_cluster_shapes:
661
+ for runtime_input_datatype in runtime_input_datatypes:
662
+ for swizzle_size in swizzle_sizes:
663
+ grid_size = waves * sm_count
664
+ cluster_shape_m, cluster_shape_n, cluster_shape_k = tuple(cluster_shape)
665
+ if cluster_shape_m >= cluster_shape_n:
666
+ grid_m = cluster_shape_m
667
+ grid_n = grid_size / grid_m
668
+ grid_n = max( int((grid_n + cluster_shape_n - 1) / cluster_shape_n) * cluster_shape_n, 1)
669
+ else:
670
+ grid_n = cluster_shape_n
671
+ grid_m = grid_size / grid_n
672
+ grid_m = max( int((grid_m + cluster_shape_m - 1) / cluster_shape_m) * cluster_shape_m, 1)
673
+
674
+ verification_required = False
675
+ if mode == "functional_L0" or mode == "functional_L1":
676
+ if '_void_' not in kernel_name:
677
+ verification_required = True
678
+
679
+ m = max(int(grid_m * cta_tile_shape_m), alignment_ab_max)
680
+ n = max(int(grid_n * cta_tile_shape_n), alignment_ab_max)
681
+ k = int(k)
682
+
683
+ # For functional testing, we want to perturb just a little from even shapes.
684
+ # Only do this if the perturbation does not cause one of the dimensions of the
685
+ # problem size to go to zero. This can occur for blockscaling kernels for which
686
+ # the alignment requirements for A and B can be quite large (e.g., 256).
687
+ if m > alignment_shift_m:
688
+ m -= alignment_shift_m
689
+ if n > alignment_shift_n:
690
+ n -= alignment_shift_n
691
+
692
+ if '_n32t32_' in kernel_name:
693
+ continue
694
+ batch_count = 1
695
+ if mode == "functional_L0" or mode == "functional_L1" :
696
+ if index_waves == 0 and index_k == 0 :
697
+ batch_count = 3 if mode == "functional_L0" else 5
698
+ gemm_op = "gemm"
699
+
700
+ grouped = is_grouped(manifest.operations_by_name[kernel_name].gemm_kind)
701
+ num_groups = 1
702
+ if grouped:
703
+ gemm_op = "grouped_gemm"
704
+ num_groups = 3 # small to limit test time in host block-scaled reference kernels
705
+ batch_count = 1
706
+ elif "bstensorop" in kernel_name:
707
+ gemm_op = "block_scaled_gemm"
708
+ elif is_blockwise(manifest.operations_by_name[kernel_name].gemm_kind):
709
+ gemm_op = "blockwise_gemm"
710
+
711
+ problem_size_category = ['smallK','largeK'][index_k] + '_' + ['beta==0','beta!=0'][bool(beta)]
712
+
713
+ assert m > 0 and n > 0 and k > 0
714
+
715
+ # Emit per-testcase metadata for perf testing usage, eventually in perf database
716
+ metadata_dict = {
717
+ "input_params": {
718
+ 'problem_size_category' : problem_size_category,
719
+ 'operation' : _getSubOperationType(operation),
720
+ 'datatype' : data_types,
721
+ 'layout' : layout3x,
722
+ 'm' : m,
723
+ 'n' : n,
724
+ 'k' : k,
725
+ 'beta' : beta,
726
+ 'flops_per_byte' : _computeFlopsPerByte(operation, m, n, k, batch_count, beta, num_groups)
727
+ },
728
+ "runtime_params": {
729
+ 'ctas_per_mma_instruction' : ctas_per_mma_instruction,
730
+ 'tilesize_m' : cta_tile_shape_m,
731
+ 'tilesize_n' : cta_tile_shape_n,
732
+ 'tilesize_k' : cta_tile_shape_k,
733
+ 'cluster_shape_m' : cluster_shape_m,
734
+ 'cluster_shape_n' : cluster_shape_n,
735
+ }
736
+ }
737
+
738
+ cluster_m_fallback = ctas_per_mma_instruction if dynamic_cluster else cluster_shape_m
739
+ cluster_n_fallback = 1 if dynamic_cluster else cluster_shape_n
740
+ cluster_k_fallback = 1 if dynamic_cluster else cluster_shape_k
741
+
742
+
743
+ if dynamic_datatype:
744
+ runtime_datatype_a, runtime_datatype_b = tuple(runtime_input_datatype)
745
+ metadata_dict["runtime_params"]["runtime_datatype_a"] = runtime_datatype_a
746
+ metadata_dict["runtime_params"]["runtime_datatype_b"] = runtime_datatype_b
747
+
748
+ testcase_metadata = [
749
+ f"cutlass_profiler --operation={gemm_op}" +
750
+ (f" --verification-providers=device --providers=cutlass" if profiler_flags_for_verification == "device" else " --mode=trace") +
751
+ f" --error-on-no-match --error-if-nothing-is-profiled" +
752
+ f" --kernels={kernel_name}" +
753
+ f" --m={str(m)}" +
754
+ f" --n={str(n)}" +
755
+ f" --k={str(k)}" +
756
+ (f" --num_groups={str(num_groups)}" if grouped else "") +
757
+ f" --cluster_m={str(cluster_shape_m)}" +
758
+ f" --cluster_n={str(cluster_shape_n)}" +
759
+ f" --cluster_k={str(cluster_shape_k)}" +
760
+ f" --cluster_m_fallback={str(cluster_m_fallback)}" +
761
+ f" --cluster_n_fallback={str(cluster_n_fallback)}" +
762
+ f" --cluster_k_fallback={str(cluster_k_fallback)}" +
763
+ f" --beta={str(beta)}" +
764
+ ("" if grouped else f" --batch_count={str(batch_count)}") +
765
+ f" --swizzle_size={str(swizzle_size)}" +
766
+ f" --verification-required={str(verification_required).lower()}"
767
+ ] \
768
+
769
+ output_dynamic_datatype = dynamic_datatype
770
+ if output_dynamic_datatype:
771
+ testcase_metadata[0] += (f" --runtime_input_datatype_a={runtime_datatype_a}" +
772
+ f" --runtime_input_datatype_b={runtime_datatype_b}")
773
+
774
+ testcase_metadata.append(json.dumps(metadata_dict))
775
+ testlist_csv_rows.append(testcase_metadata)
776
+ testcase_counter += 1
777
+
778
+ alpha = 1.0
779
+
780
+ if dynamic_datatype:
781
+ hashed_kernel_name = transform_hashed_string(hashed_kernel_name, runtime_datatype_a, runtime_datatype_b)
782
+
783
+ # If kernel_name is new, initialize its feature set with defaults
784
+ if hashed_kernel_name not in kernel_features:
785
+ kernel_features[hashed_kernel_name] = {
786
+ "is_support_dynamic_cluster": False,
787
+ "is_support_dynamic_datatype": False,
788
+ }
789
+
790
+ # Update features for the hashed kernel name
791
+ kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] |= dynamic_cluster
792
+ kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] |= dynamic_datatype
793
+
794
+ if hashed_kernel_name not in auditlist_csv_params_map:
795
+ auditlist_csv_params_map[hashed_kernel_name] = []
796
+
797
+ audit_row_params = get_kernel_params(
798
+ operation,
799
+ hashed_kernel_name,
800
+ (cluster_shape_m, cluster_shape_n, cluster_shape_k),
801
+ (cluster_m_fallback, cluster_n_fallback, cluster_k_fallback),
802
+ (m, n, k, batch_count),
803
+ alpha, beta,
804
+ dynamic_datatype, dynamic_cluster
805
+ )
806
+
807
+ auditlist_csv_params_map[hashed_kernel_name].append(audit_row_params)
808
+
809
+ if hashed_kernel_name not in auditlist_csv_map:
810
+ audit_row = get_kernel_features(operation, hashed_kernel_name, dynamic_datatype, runtime_input_datatype)
811
+ auditlist_csv_map[hashed_kernel_name] = audit_row
812
+
813
+ with open(outfile_name, 'w') as testlist_csv:
814
+ csv_writer = csv.writer(testlist_csv, delimiter=',')
815
+ csv_writer.writerow(testlist_csv_fields)
816
+ csv_writer.writerows(testlist_csv_rows)
817
+
818
+ with open(audit_file_name, 'w') as auditlist_csv:
819
+ csv_writer = csv.writer(auditlist_csv, delimiter=',')
820
+ csv_writer.writerow(audit_csv_fields)
821
+ for hashed_kernel_name, row in auditlist_csv_map.items():
822
+ # Append the dynamic features as "Y" or "N"
823
+ dynamic_cluster_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_cluster"] else "N"
824
+ dynamic_datatype_flag = "Y" if kernel_features[hashed_kernel_name]["is_support_dynamic_datatype"] else "N"
825
+ test_count = len(auditlist_csv_params_map[hashed_kernel_name])
826
+ csv_writer.writerow(row + [dynamic_cluster_flag, dynamic_datatype_flag, test_count])
827
+
828
+ with open(audit_file_params_name, 'w') as auditlist_csv:
829
+ csv_writer = csv.writer(auditlist_csv, delimiter=',')
830
+ csv_writer.writerow(audit_csv_runtime_fields)
831
+ for kernel_index, (hashed_kernel_name, rows) in enumerate(auditlist_csv_params_map.items(), start=1):
832
+ for i, row in enumerate(rows):
833
+ if i == 0:
834
+ csv_writer.writerow([kernel_index, hashed_kernel_name] + row)
835
+ else:
836
+ csv_writer.writerow(["", ""] + row)
837
+
838
+ print(f"Generated a total of {testcase_counter} test cases for {kernels_emitted} kernels out of {kernels_total} total.")
839
+
840
+ # Generate a newline separated list of kernel filters
841
+ assert(len(kernel_name_set) == kernels_emitted)
842
+ output_filter_enabled = True
843
+ if output_filter_enabled:
844
+ kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
845
+ with open(kernel_filter_outfile_name, "w") as file:
846
+ kernel_name_set = set(map(lambda x: x.replace("_epi_tma", ""), kernel_name_set))
847
+ for kernel_name in kernel_name_set:
848
+ file.write(kernel_name + "\n")
849
+
850
+ # Sort L0 and L1 kernel list and csv file to avoid mixing cutlass3.x kernels and sm120_mma kernels in cutlass2.x generated together.
851
+ if mode == "functional_L0" or mode == "functional_L1":
852
+ # Sort the .csv file
853
+ outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm.csv")
854
+ with open(outfile_name) as file:
855
+ data = file.readlines()
856
+ data.sort()
857
+ with open(outfile_name, 'w') as file:
858
+ for i in range(len(data)):
859
+ file.write(data[i])
860
+ # Sort the kernel list
861
+ kernel_filter_outfile_name = os.path.join(curr_build_dir, f"FK_{mode}_testlist_SM{arch}_cutlass3x_gemm_kernel_filter.list")
862
+ with open(kernel_filter_outfile_name) as file:
863
+ data = file.readlines()
864
+ data.sort()
865
+ with open(kernel_filter_outfile_name, 'w') as file:
866
+ for i in range(len(data)):
867
+ file.write(data[i])
868
+
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/gemm_operation.py ADDED
@@ -0,0 +1,1613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting GEMM kernels
35
+ """
36
+
37
+ import collections
38
+ import enum
39
+ import functools
40
+ import logging
41
+ import operator
42
+ import os.path
43
+ import shutil
44
+
45
+ try:
46
+ import builtins
47
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
48
+ raise ImportError("Disabling attempt to import cutlass_library")
49
+ from cutlass_library.library import *
50
+ except ImportError:
51
+ from library import *
52
+
53
+ _LOGGER = logging.getLogger(__name__)
54
+
55
+ ###################################################################################################
56
+ #
57
+ # Data structure modeling a GEMM operation
58
+ #
59
+ ###################################################################################################
60
+
61
+ #
62
+ class GemmOperation:
63
+ #
64
+ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
65
+ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None,
66
+ kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto,
67
+ tile_scheduler = TileSchedulerType.Default, mixed_input_mode = None, mixed_input_shuffle = False,
68
+ ScaleFactorA = None, ScaleFactorB = None, ScaleFactorD = None,
69
+ ScaleFactorMVecSize = None, ScaleFactorNVecSize = None, ScaleFactorKVecSize = None):
70
+
71
+ kinds_3x = {
72
+ GemmKind.Universal3x,
73
+ GemmKind.SparseUniversal3x,
74
+ GemmKind.BlockScaledUniversal3x,
75
+ GemmKind.GroupedUniversal3x,
76
+ GemmKind.GroupedBlockScaledUniversal3x,
77
+ GemmKind.BlockwiseUniversal3x,
78
+ GemmKind.GroupedBlockwiseUniversal3x,
79
+ }
80
+ self.is_3x = gemm_kind in kinds_3x
81
+ self.prefix = "3x" if self.is_3x else ""
82
+ self.operation_kind = OperationKind.Gemm
83
+ self.arch = arch
84
+ self.tile_description = tile_description
85
+ self.gemm_kind = gemm_kind
86
+ self.A = A
87
+ self.B = B
88
+ self.C = C
89
+ self.D = D
90
+
91
+ if is_block_scaled(gemm_kind):
92
+ self.ScaleFactorA = ScaleFactorA
93
+ self.ScaleFactorB = ScaleFactorB
94
+ self.ScaleFactorD = ScaleFactorD["tensor"]
95
+ self.ScaleFactorVectorSize = ScaleFactorD["vector_size"]
96
+
97
+ if is_blockwise(gemm_kind):
98
+ self.ScaleFactorMVecSize = ScaleFactorMVecSize
99
+ self.ScaleFactorNVecSize = ScaleFactorNVecSize
100
+ self.ScaleFactorKVecSize = ScaleFactorKVecSize
101
+
102
+ if self.D == None:
103
+ self.D = self.C
104
+
105
+ if not self.is_3x:
106
+ assert(kernel_schedule == KernelScheduleType.ScheduleAuto)
107
+ assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto)
108
+ self.kernel_schedule = kernel_schedule
109
+ self.epilogue_schedule = epilogue_schedule
110
+ self.element_epilogue = element_epilogue
111
+ self.epilogue_functor = epilogue_functor
112
+
113
+ if self.is_3x and epilogue_functor == EpilogueFunctor.LinearCombination:
114
+ self.epilogue_functor = EpilogueFunctor3x.LinearCombination
115
+
116
+ self.swizzling_functor = swizzling_functor
117
+ self.tile_scheduler = tile_scheduler
118
+
119
+ # Only enable mixed input mode and mixed input shuffle for Hopper
120
+ self.mixed_input_mode = None
121
+ if self.is_mixed_input() and self.arch >= 90 and self.arch < 100:
122
+ self.mixed_input_mode = mixed_input_mode
123
+ self.mixed_input_shuffle = (self.mixed_input_mode is not None) and mixed_input_shuffle
124
+
125
+ #
126
+ def is_complex(self):
127
+ complex_operators = [
128
+ MathOperation.multiply_add_complex,
129
+ MathOperation.multiply_add_complex_gaussian,
130
+ MathOperation.multiply_add_complex_fast_f32
131
+ ]
132
+ return self.tile_description.math_instruction.math_operation in complex_operators
133
+
134
+ #
135
+ def is_mixed_input(self):
136
+ return self.A.element != self.B.element
137
+
138
+ #
139
+ def is_planar_complex(self):
140
+ return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
141
+
142
+ #
143
+ def accumulator_type(self):
144
+ accum = self.tile_description.math_instruction.element_accumulator
145
+
146
+ if self.is_complex():
147
+ return get_complex_from_real(accum)
148
+
149
+ return accum
150
+
151
+ #
152
+ def short_math_name(self):
153
+ if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
154
+ return "g%s" % ShortDataTypeNames[self.accumulator_type()]
155
+ return ShortDataTypeNames[self.accumulator_type()]
156
+
157
+
158
+ #
159
+ def core_name(self):
160
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
161
+
162
+ inst_shape = ''
163
+ inst_operation = ''
164
+ intermediate_type = ''
165
+
166
+ math_operations_map = {
167
+ MathOperation.xor_popc: 'xor',
168
+ MathOperation.and_popc: 'and',
169
+ MathOperation.multiply_add_fast_accum: 'fastaccum',
170
+ }
171
+
172
+ tensor_ops = [
173
+ OpcodeClass.TensorOp,
174
+ OpcodeClass.WmmaTensorOp,
175
+ OpcodeClass.SparseTensorOp,
176
+ OpcodeClass.BlockScaledTensorOp,
177
+ ]
178
+
179
+ is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops
180
+
181
+ if is_tensor_op:
182
+
183
+ math_op = self.tile_description.math_instruction.math_operation
184
+ math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
185
+
186
+ inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) if not self.is_3x else ""
187
+
188
+ inst_shape += math_op_string
189
+
190
+ if self.tile_description.math_instruction.element_a != self.A.element and \
191
+ self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
192
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
193
+
194
+ short_math_name = self.short_math_name() if not self.is_3x else ""
195
+
196
+ return "%s%s%s%s" % (short_math_name, inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
197
+
198
+ # Generates a string representing the MMA instruction.
199
+ def extended_name(self):
200
+ ''' Append data types if they differ from compute type. '''
201
+ element_sfa = ""
202
+ element_sfb = ""
203
+ if self.is_complex():
204
+ extended_name = "${core_name}"
205
+ else:
206
+ if self.is_mixed_input():
207
+ extended_name = "${core_name}_${element_a}_${element_b}"
208
+ if self.C.element != self.tile_description.math_instruction.element_accumulator:
209
+ extended_name = "${element_c}_" + extended_name
210
+ elif is_blockwise(self.gemm_kind):
211
+ extended_name = "${core_name}_${element_sfa}x${element_a}_${element_sfb}x${element_b}"
212
+ element_sfa = DataTypeNames[self.accumulator_type()]
213
+ element_sfb = DataTypeNames[self.accumulator_type()]
214
+ else:
215
+ extended_name = "${core_name}"
216
+ if self.C.element != self.tile_description.math_instruction.element_accumulator:
217
+ extended_name = "${element_c}_" + extended_name
218
+ if self.A.element != self.tile_description.math_instruction.element_accumulator:
219
+ extended_name += "_${element_a}"
220
+
221
+ extended_name = SubstituteTemplate(extended_name, {
222
+ 'element_a': DataTypeNames[self.A.element],
223
+ 'element_sfa' : element_sfa,
224
+ 'element_b': DataTypeNames[self.B.element],
225
+ 'element_sfb' : element_sfb,
226
+ 'element_c': DataTypeNames[self.C.element],
227
+ 'core_name': self.core_name()
228
+ })
229
+
230
+ return extended_name
231
+
232
+ #
233
+ def mixed_input_mode_name(self):
234
+ mode_name_mapping = {
235
+ MixedInputMode.ConvertOnly: "_cvt",
236
+ MixedInputMode.ScaleOnly: "_scl",
237
+ MixedInputMode.ScaleWithZeroPoint: "_sclzr"
238
+ }
239
+ mode_name = mode_name_mapping.get(self.mixed_input_mode, "")
240
+ if self.mixed_input_shuffle:
241
+ mode_name = mode_name + "_shfl"
242
+ return mode_name
243
+
244
+ def extended_name_3x(self):
245
+ '''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
246
+ extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
247
+ element_a = DataTypeNames[self.A.element],
248
+ element_b = DataTypeNames[self.B.element],
249
+ element_acc = DataTypeNames[self.accumulator_type()],
250
+ element_c = DataTypeNames[self.C.element],
251
+ element_d = DataTypeNames[self.D.element],
252
+ core_name = self.core_name())
253
+
254
+ if is_block_scaled(self.gemm_kind):
255
+ d_type_names = DataTypeNames[self.D.element]
256
+
257
+ if self.ScaleFactorD.element != DataType.void:
258
+ d_type_names = DataTypeNames[self.ScaleFactorD.element] + "x" + d_type_names
259
+
260
+ extended_name = "{core_name}_{element_sfa}x{element_a}_{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
261
+ element_sfa = DataTypeNames[self.ScaleFactorA],
262
+ element_a = DataTypeNames[self.A.element],
263
+ element_sfb = DataTypeNames[self.ScaleFactorB],
264
+ element_b = DataTypeNames[self.B.element],
265
+ element_acc = DataTypeNames[self.accumulator_type()],
266
+ element_c = DataTypeNames[self.C.element],
267
+ element_d = d_type_names,
268
+ core_name = self.core_name())
269
+
270
+ if is_blockwise(self.gemm_kind):
271
+ d_type_names = DataTypeNames[self.D.element]
272
+
273
+ extended_name = "{core_name}_{sfvec_m_size}x{sfvec_k_size}{element_sfa}x{element_a}_{sfvec_n_size}x{sfvec_k_size}{element_sfb}x{element_b}_{element_acc}_{element_c}_{element_d}".format(
274
+ element_sfa = DataTypeNames[self.accumulator_type()],
275
+ element_a = DataTypeNames[self.A.element],
276
+ element_sfb = DataTypeNames[self.accumulator_type()],
277
+ element_b = DataTypeNames[self.B.element],
278
+ element_acc = DataTypeNames[self.accumulator_type()],
279
+ element_c = DataTypeNames[self.C.element],
280
+ element_d = d_type_names,
281
+ sfvec_m_size = self.ScaleFactorMVecSize,
282
+ sfvec_n_size = self.ScaleFactorNVecSize,
283
+ sfvec_k_size = self.ScaleFactorKVecSize,
284
+ core_name = self.core_name())
285
+
286
+ if self.mixed_input_mode != None:
287
+ extended_name = extended_name + self.mixed_input_mode_name()
288
+ return extended_name
289
+
290
+ def datatype_name_3x(self):
291
+ '''Generates a string representing the MMA atom. Assumes accumulator type is C type.'''
292
+ datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
293
+ element_a = DataTypeNames[self.A.element],
294
+ element_b = DataTypeNames[self.B.element],
295
+ element_acc = DataTypeNames[self.accumulator_type()],
296
+ element_c = DataTypeNames[self.C.element],
297
+ element_d = DataTypeNames[self.D.element])
298
+ return datatype_name
299
+
300
+ # Generates a short string representing the AB layout tags (e.g. nt or tn)
301
+ def layout_name(self):
302
+ if self.is_complex() or self.is_planar_complex():
303
+ return "%s%s" % (
304
+ ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
305
+ ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
306
+ )
307
+ return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
308
+
309
+ # Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
310
+ def layout_name_3x(self):
311
+ if self.is_complex() or self.is_planar_complex():
312
+ return "{}{}{}".format(
313
+ ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
314
+ ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
315
+ ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
316
+ else:
317
+ return "{}{}{}".format(
318
+ ShortLayoutTypeNames[self.A.layout],
319
+ ShortLayoutTypeNames[self.B.layout],
320
+ ShortLayoutTypeNames[self.C.layout])
321
+
322
+ # Generates a short string representing underlying kernel schedule type
323
+ def kernel_schedule_name_3x(self):
324
+ return KernelScheduleSuffixes[self.kernel_schedule]
325
+
326
+ # Generates a short string representing underlying epilogue schedule type
327
+ def epilogue_schedule_name_3x(self):
328
+
329
+ if is_block_scaled(self.gemm_kind):
330
+ if self.ScaleFactorD.element != DataType.void:
331
+ return EpilogueScheduleSuffixes[self.epilogue_schedule] + "_epiVs" + str(self.ScaleFactorVectorSize)+ShortLayoutTypeNames[self.ScaleFactorD.layout]
332
+
333
+ return EpilogueScheduleSuffixes[self.epilogue_schedule]
334
+
335
+ # Generate a short string representing the operation class
336
+ def opcode_class_name(self):
337
+ return OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
338
+
339
+ def get_collective_tile_shape(self):
340
+ """
341
+ Get the tile shape passed to the collective builder.
342
+ On Blackwell, this is different than the operation.tile_description.tile_shape.
343
+ """
344
+ is_sm100_kernel = (self.arch == 100 or self.arch == 103)
345
+ if not is_sm100_kernel:
346
+ return self.tile_description.tile_shape
347
+
348
+ opcode_class_main = self.tile_description.math_instruction.opcode_class
349
+ instruction_shape = self.tile_description.math_instruction.instruction_shape
350
+ tile_shape_m, tile_shape_n, tile_shape_k = self.tile_description.tile_shape
351
+ if opcode_class_main in [OpcodeClass.TensorOp, OpcodeClass.BlockScaledTensorOp, OpcodeClass.SparseTensorOp]:
352
+ tile_shape_m = instruction_shape[0]
353
+ tile_shape_n = instruction_shape[1]
354
+ return (tile_shape_m, tile_shape_n, tile_shape_k)
355
+
356
+ # Generates the full kernel function name
357
+ def procedural_name(self):
358
+ return self._procedural_name
359
+
360
+ @functools.cached_property
361
+ def _procedural_name(self):
362
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
363
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
364
+ if self.arch >= 90:
365
+ kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}{ct}{cs}_{l}_{s}_align{al}{t}{k}{e}"
366
+ tile_shape = self.get_collective_tile_shape()
367
+ return kernel_name_template.format(
368
+ p = self.prefix,
369
+ ar = self.arch,
370
+ op = opcode_class_name,
371
+ ex = self.extended_name_3x(),
372
+ ct = '_' + 'x'.join([str(i) for i in tile_shape]) if tile_shape[0] > 0 else "",
373
+ cs = '_' + 'x'.join([str(i) for i in self.tile_description.cluster_shape]),
374
+ l = self.tile_description.stages,
375
+ s = self.layout_name_3x(),
376
+ al = str(max(self.A.alignment, self.B.alignment)),
377
+ t = TileSchedulerSuffixes[self.tile_scheduler],
378
+ k = self.kernel_schedule_name_3x(),
379
+ e = self.epilogue_schedule_name_3x())
380
+ else:
381
+ threadblock = self.tile_description.procedural_name()
382
+ return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
383
+ p = self.prefix,
384
+ op = opcode_class_name,
385
+ ex = self.extended_name(),
386
+ tb = threadblock,
387
+ l = self.layout_name(),
388
+ a = str(max(self.A.alignment, self.B.alignment)))
389
+
390
+ #
391
+ def configuration_name(self):
392
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
393
+ return self.procedural_name()
394
+
395
+ def __hash__(self):
396
+ return hash(self.configuration_name())
397
+
398
+ def __eq__(self, other):
399
+ return self.configuration_name() == other.configuration_name()
400
+
401
+ ###################################################################################################
402
+ #
403
+ # Data structure modeling a grouped GEMM operation
404
+ #
405
+ ###################################################################################################
406
+
407
+ #
408
+ class GroupedGemmOperation(GemmOperation):
409
+ #
410
+ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
411
+ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
412
+ scheduler_mode = GroupScheduleMode.Device):
413
+ super().__init__(gemm_kind, arch, tile_description, A, B, C, element_epilogue, \
414
+ epilogue_functor, swizzling_functor)
415
+
416
+ self.scheduler_mode = scheduler_mode
417
+
418
+ #
419
+ def procedural_name(self):
420
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
421
+ base = super().procedural_name()
422
+ return SubstituteTemplate(
423
+ base + "_schedule${schedule}",
424
+ {
425
+ 'schedule': ShortGroupScheduleModeNames[self.scheduler_mode]
426
+ })
427
+
428
+
429
+ ###################################################################################################
430
+ #
431
+ # Emits single instances of a CUTLASS device-wide operator
432
+ #
433
+ ###################################################################################################
434
+
435
+ #
436
+ class EmitGemmInstance:
437
+ ''' Responsible for emitting a CUTLASS template definition'''
438
+
439
+ def __init__(self, operation_suffix = ''):
440
+ self.operation_suffix = operation_suffix
441
+ self.includes = []
442
+ self.gemm_template = """
443
+ // Gemm operator ${operation_name}
444
+ using Operation_${operation_name} = cutlass::gemm::device::Gemm<
445
+ ${element_a}, ${layout_a},
446
+ ${element_b}, ${layout_b},
447
+ ${element_c}, ${layout_c},
448
+ ${element_accumulator},
449
+ ${opcode_class},
450
+ ${arch},
451
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
452
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
453
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
454
+ ${epilogue_functor}<
455
+ ${element_c},
456
+ ${epilogue_vector_length},
457
+ ${element_accumulator},
458
+ ${element_epilogue}
459
+ >,
460
+ ${swizzling_functor},
461
+ ${stages},
462
+ ${align_a},
463
+ ${align_b},
464
+ false,
465
+ ${math_operation}
466
+ ${residual}
467
+ >;
468
+ """
469
+ self.gemm_complex_template = """
470
+ // Gemm operator ${operation_name}
471
+ using Operation_${operation_name} = cutlass::gemm::device::GemmComplex<
472
+ ${element_a}, ${layout_a},
473
+ ${element_b}, ${layout_b},
474
+ ${element_c}, ${layout_c},
475
+ ${element_accumulator},
476
+ ${opcode_class},
477
+ ${arch},
478
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
479
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
480
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
481
+ ${epilogue_functor}<
482
+ ${element_c},
483
+ ${epilogue_vector_length},
484
+ ${element_accumulator},
485
+ ${element_epilogue}
486
+ >,
487
+ ${swizzling_functor},
488
+ ${stages},
489
+ ${transform_a},
490
+ ${transform_b},
491
+ ${math_operation}
492
+ ${residual}
493
+ >;
494
+ """
495
+
496
+ #
497
+ def instance_template(self):
498
+ return """
499
+ ${compile_guard_start}
500
+ manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
501
+ ${compile_guard_end}
502
+ """
503
+
504
+ #
505
+ def emit(self, operation):
506
+
507
+ warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
508
+
509
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
510
+
511
+ residual = ''
512
+
513
+ values = {
514
+ 'operation_name': operation.procedural_name(),
515
+ 'element_a': DataTypeTag[operation.A.element],
516
+ 'layout_a': LayoutTag[operation.A.layout],
517
+ 'element_b': DataTypeTag[operation.B.element],
518
+ 'layout_b': LayoutTag[operation.B.layout],
519
+ 'element_c': DataTypeTag[operation.C.element],
520
+ 'layout_c': LayoutTag[operation.C.layout],
521
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
522
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
523
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
524
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
525
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
526
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
527
+ 'warp_shape_m': str(warp_shape[0]),
528
+ 'warp_shape_n': str(warp_shape[1]),
529
+ 'warp_shape_k': str(warp_shape[2]),
530
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
531
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
532
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
533
+ 'epilogue_vector_length': str(epilogue_vector_length),
534
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
535
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
536
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
537
+ 'stages': str(operation.tile_description.stages),
538
+ 'align_a': str(operation.A.alignment),
539
+ 'align_b': str(operation.B.alignment),
540
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
541
+ 'transform_b': ComplexTransformTag[operation.B.complex_transform],
542
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
543
+ 'residual': residual
544
+ }
545
+
546
+ template = self.gemm_complex_template if operation.is_complex() else self.gemm_template
547
+
548
+ return SubstituteTemplate(template, values)
549
+
550
+ ###################################################################################################
551
+
552
+ class EmitSparseGemmInstance:
553
+ ''' Responsible for emitting a CUTLASS template definition'''
554
+
555
+ def __init__(self, operation_suffix = ''):
556
+ self.operation_suffix = operation_suffix
557
+ self.includes = []
558
+ self.gemm_template = """
559
+ // Gemm operator ${operation_name}
560
+ using Operation_${operation_name} = cutlass::gemm::device::SparseGemm<
561
+ ${element_a}, ${layout_a},
562
+ ${element_b}, ${layout_b},
563
+ ${element_c}, ${layout_c},
564
+ ${element_accumulator},
565
+ ${opcode_class},
566
+ ${arch},
567
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
568
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
569
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
570
+ ${epilogue_functor}<
571
+ ${element_c},
572
+ ${epilogue_vector_length},
573
+ ${element_accumulator},
574
+ ${element_epilogue}
575
+ >,
576
+ ${swizzling_functor},
577
+ ${stages},
578
+ ${align_a},
579
+ ${align_b},
580
+ false,
581
+ ${math_operation}
582
+ ${residual}
583
+ >;
584
+ """
585
+
586
+ #
587
+ def instance_template(self):
588
+ return """
589
+ ${compile_guard_start}
590
+ manifest.append(new ${gemm_kind}<Operation_${operation_name}>("${operation_name}"));
591
+ ${compile_guard_end}
592
+ """
593
+
594
+ #
595
+ def emit(self, operation):
596
+
597
+ warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
598
+
599
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
600
+
601
+ residual = ''
602
+
603
+ values = {
604
+ 'operation_name': operation.procedural_name(),
605
+ 'element_a': DataTypeTag[operation.A.element],
606
+ 'layout_a': LayoutTag[operation.A.layout],
607
+ 'element_b': DataTypeTag[operation.B.element],
608
+ 'layout_b': LayoutTag[operation.B.layout],
609
+ 'element_c': DataTypeTag[operation.C.element],
610
+ 'layout_c': LayoutTag[operation.C.layout],
611
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
612
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
613
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
614
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
615
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
616
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
617
+ 'warp_shape_m': str(warp_shape[0]),
618
+ 'warp_shape_n': str(warp_shape[1]),
619
+ 'warp_shape_k': str(warp_shape[2]),
620
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
621
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
622
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
623
+ 'epilogue_vector_length': str(epilogue_vector_length),
624
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
625
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
626
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
627
+ 'stages': str(operation.tile_description.stages),
628
+ 'align_a': str(operation.A.alignment),
629
+ 'align_b': str(operation.B.alignment),
630
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
631
+ 'transform_b': ComplexTransformTag[operation.B.complex_transform],
632
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
633
+ 'residual': residual
634
+ }
635
+
636
+ template = self.gemm_template
637
+
638
+ return SubstituteTemplate(template, values)
639
+
640
+ ###################################################################################################
641
+
642
+
643
+ #
644
+ class EmitGemmUniversalInstance:
645
+ ''' Responsible for emitting a CUTLASS template definition'''
646
+
647
+ def __init__(self, operation_suffix = ''):
648
+ self.operation_suffix = operation_suffix
649
+ self.includes = [
650
+ "cutlass/cutlass.h",
651
+ "cutlass/numeric_types.h",
652
+ "cutlass/arch/arch.h",
653
+ "cutlass/arch/mma.h",
654
+ "cutlass/layout/matrix.h",
655
+ "cutlass/gemm/device/gemm.h",
656
+ "cutlass/gemm/device/gemm_universal_adapter.h",
657
+ "cutlass/gemm/kernel/default_gemm_universal.h",
658
+ ]
659
+ self.builtin_epilogue_functor_template = """
660
+ ${epilogue_functor}<
661
+ ${element_c},
662
+ ${epilogue_vector_length},
663
+ ${element_accumulator},
664
+ ${element_epilogue}
665
+ >
666
+ """
667
+ self.gemm_template = """
668
+ // Gemm operator ${operation_name}
669
+ using ${operation_name}_base =
670
+ typename cutlass::gemm::kernel::DefaultGemmUniversal<
671
+ ${element_b}, ${layout_b}, ${transform_b}, ${align_b}, // transposed B operand
672
+ ${element_a}, ${layout_a}, ${transform_a}, ${align_a}, // transposed A operand
673
+ ${element_c}, ${layout_c},
674
+ ${element_accumulator},
675
+ ${opcode_class},
676
+ ${arch},
677
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
678
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
679
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
680
+ ${epilogue_functor},
681
+ ${swizzling_functor},
682
+ ${stages},
683
+ ${math_operation}
684
+ >::GemmKernel;
685
+
686
+ // Define named type
687
+ struct ${operation_name}${operation_suffix} :
688
+ public ${operation_name}_base { };
689
+ """
690
+ self.gemm_template_interleaved = """
691
+ // Gemm operator ${operation_name}
692
+ using ${operation_name}_base =
693
+ typename cutlass::gemm::kernel::DefaultGemmUniversal<
694
+ ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
695
+ ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
696
+ ${element_c}, ${layout_c},
697
+ ${element_accumulator},
698
+ ${opcode_class},
699
+ ${arch},
700
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
701
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
702
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
703
+ ${epilogue_functor},
704
+ ${swizzling_functor},
705
+ ${stages},
706
+ ${math_operation}
707
+ >::GemmKernel;
708
+
709
+ // Define named type
710
+ struct ${operation_name}${operation_suffix} :
711
+ public ${operation_name}_base { };
712
+ """
713
+
714
+ #
715
+ def instance_template(self):
716
+ return """
717
+ ${compile_guard_start}
718
+ manifest.append(new ${gemm_kind}<
719
+ cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
720
+ >("${operation_name}"));
721
+ ${compile_guard_end}
722
+ """
723
+
724
+ #
725
+ def emit(self, operation):
726
+
727
+ threadblock_shape = operation.tile_description.threadblock_shape
728
+ warp_count = operation.tile_description.warp_count
729
+
730
+ warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
731
+
732
+ transpose_layouts = {
733
+ LayoutType.ColumnMajor: LayoutType.RowMajor,
734
+ LayoutType.RowMajor: LayoutType.ColumnMajor
735
+ }
736
+
737
+ if operation.A.layout in transpose_layouts.keys() and \
738
+ operation.B.layout in transpose_layouts.keys() and \
739
+ operation.C.layout in transpose_layouts.keys():
740
+
741
+ instance_layout_A = transpose_layouts[operation.A.layout]
742
+ instance_layout_B = transpose_layouts[operation.B.layout]
743
+ instance_layout_C = transpose_layouts[operation.C.layout]
744
+
745
+ gemm_template = self.gemm_template
746
+ else:
747
+ instance_layout_A, instance_layout_B, instance_layout_C = \
748
+ (operation.A.layout, operation.B.layout, operation.C.layout)
749
+
750
+ gemm_template = self.gemm_template_interleaved
751
+ #
752
+
753
+ # Support built-in epilogue functors or user-defined functions
754
+ if isinstance(operation.epilogue_functor, enum.Enum):
755
+
756
+ epilogue_vector_length = \
757
+ min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
758
+
759
+ values = {
760
+ 'epilogue_vector_length': str(epilogue_vector_length),
761
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
762
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
763
+ }
764
+ epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
765
+ else:
766
+ epilogue_functor = self.epilogue_functor.emit_declaration()
767
+ #
768
+
769
+ values = {
770
+ 'operation_name': operation.procedural_name(),
771
+ 'operation_suffix': self.operation_suffix,
772
+ 'element_a': DataTypeTag[operation.A.element],
773
+ 'layout_a': LayoutTag[instance_layout_A],
774
+ 'element_b': DataTypeTag[operation.B.element],
775
+ 'layout_b': LayoutTag[instance_layout_B],
776
+ 'element_c': DataTypeTag[operation.C.element],
777
+ 'layout_c': LayoutTag[instance_layout_C],
778
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
779
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
780
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
781
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
782
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
783
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
784
+ 'warp_shape_m': str(warp_shape[0]),
785
+ 'warp_shape_n': str(warp_shape[1]),
786
+ 'warp_shape_k': str(warp_shape[2]),
787
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
788
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
789
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
790
+ 'epilogue_functor': epilogue_functor,
791
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
792
+ 'stages': str(operation.tile_description.stages),
793
+ 'align_a': str(operation.A.alignment),
794
+ 'align_b': str(operation.B.alignment),
795
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
796
+ 'transform_b': ComplexTransformTag[operation.B.complex_transform],
797
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
798
+ }
799
+
800
+ return SubstituteTemplate(gemm_template, values)
801
+
802
+
803
+ ###################################################################################################
804
+
805
+ class EmitGemmUniversal3xInstance:
806
+ ''' Responsible for emitting a CUTLASS 3.x template definition'''
807
+
808
+ def __init__(self, operation_suffix = ''):
809
+ self.operation_suffix = operation_suffix
810
+ self.includes = [
811
+ "cutlass/cutlass.h",
812
+ "cutlass/gemm/gemm.h",
813
+ "cutlass/numeric_types.h",
814
+ "cutlass/gemm/kernel/gemm_universal.hpp",
815
+ "cutlass/gemm/collective/collective_builder.hpp",
816
+ "cutlass/epilogue/collective/collective_builder.hpp",
817
+ "cutlass/detail/blockwise_scale_layout.hpp",
818
+ ]
819
+ self.builtin_epilogue_functor_template = \
820
+ """${epilogue_functor}<
821
+ ${element_d},
822
+ ${element_epilogue},
823
+ ${element_c},
824
+ ${element_epilogue}
825
+ >"""
826
+
827
+ self.gemm_template = """
828
+
829
+ using ${operation_name}_epilogue =
830
+ typename cutlass::epilogue::collective::CollectiveBuilder<
831
+ ${arch}, ${opcode_class_epi},
832
+ cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
833
+ cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
834
+ ${epi_tile_mn},
835
+ ${element_accumulator}, ${element_epilogue},
836
+ ${element_c}, ${layout_c}, ${align_c},
837
+ ${element_d}, ${layout_d}, ${align_d},
838
+ ${epilogue_schedule},
839
+ ${epilogue_functor}
840
+ >::CollectiveOp;
841
+
842
+ ${mixed_dtype_prepare_code}
843
+ ${blockwise_prepare_code}
844
+
845
+ using ${operation_name}_mainloop =
846
+ typename cutlass::gemm::collective::CollectiveBuilder<
847
+ ${arch}, ${opcode_class_main},
848
+ ${element_a}, ${layout_a}, ${align_a},
849
+ ${element_b}, ${layout_b}, ${align_b},
850
+ ${element_accumulator},
851
+ cute::Shape<cute::_${tile_shape_m}, cute::_${tile_shape_n}, cute::_${tile_shape_k}>,
852
+ cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, ${cluster_shape_k}>,
853
+ ${stages},
854
+ ${kernel_schedule}
855
+ >::CollectiveOp;
856
+
857
+ // Gemm operator ${operation_name}
858
+ using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
859
+ ${problem_shape},
860
+ ${operation_name}_mainloop,
861
+ ${operation_name}_epilogue,
862
+ ${tile_scheduler}>;
863
+
864
+ // Define named type
865
+ struct ${operation_name} :
866
+ public ${operation_name}_base { };
867
+
868
+ """
869
+ #
870
+ def instance_template(self):
871
+ return """
872
+ ${compile_guard_start}
873
+ {
874
+ using GemmKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>;
875
+ manifest.append(
876
+ new ${gemm_kind}<GemmKernel>("${operation_name}"));
877
+ }
878
+ ${compile_guard_end}
879
+ """
880
+
881
+
882
+ def emit_block_scale_epilogue_functor(self, operation):
883
+ block_scaled_template = """
884
+ ${epilogue_functor}<
885
+ ${epi_vs},
886
+ ${element_d},
887
+ ${element_accumulator},
888
+ ${element_sfd},
889
+ ${layout_sfd},
890
+ ${element_c},
891
+ ${element_scalar}
892
+ >
893
+ """
894
+ block_scaled_values = {
895
+ 'epi_vs' : str(operation.ScaleFactorVectorSize),
896
+ 'element_d': str(DataTypeTag[operation.D.element]),
897
+ 'element_sfd': str(DataTypeTag[operation.ScaleFactorD.element]),
898
+ 'layout_sfd': LayoutTag[operation.ScaleFactorD.layout],
899
+ 'epilogue_functor': EpilogueFunctor3xTag[EpilogueFunctor3x.LinearCombinationBlockScaleFactor],
900
+ 'element_accumulator': str(DataTypeTag[operation.accumulator_type()]),
901
+ 'element_scalar': str(DataTypeTag[operation.accumulator_type()]),
902
+ 'element_c': str(DataTypeTag[operation.C.element]),
903
+ }
904
+ return SubstituteTemplate(block_scaled_template, block_scaled_values)
905
+
906
+
907
+ @staticmethod
908
+ def pointerize_if_grouped(operation, layout):
909
+ return layout if not is_grouped(operation.gemm_kind) else layout + "* "
910
+
911
+ @staticmethod
912
+ def transform_layout_A_if_blockwise(operation, layout):
913
+ layout_sfa = f"{operation.procedural_name()}_LayoutSFA"
914
+ layout_sfa = layout_sfa if not is_grouped(operation.gemm_kind) else layout_sfa + "* "
915
+ return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfa}>"
916
+
917
+ @staticmethod
918
+ def transform_layout_B_if_blockwise(operation, layout):
919
+ layout_sfb = f"{operation.procedural_name()}_LayoutSFB"
920
+ layout_sfb = layout_sfb if not is_grouped(operation.gemm_kind) else layout_sfb + "* "
921
+ return layout if not is_blockwise(operation.gemm_kind) else f"cute::tuple<{layout}, {layout_sfb}>"
922
+
923
+ @staticmethod
924
+ def problem_shape(operation):
925
+ gemm_shape_type = "cute::Shape<int,int,int,int>"
926
+ grouped_gemm_shape_type = "cute::Shape<int,int,int>"
927
+ grouped_gemm_shape_type = "cutlass::gemm::GroupProblemShape<" + grouped_gemm_shape_type + ">"
928
+
929
+ return gemm_shape_type if not is_grouped(operation.gemm_kind) else grouped_gemm_shape_type
930
+
931
+ def emit(self, operation):
932
+ _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
933
+ _LOGGER.debug("*** operation.procedural_name(): " + operation.procedural_name())
934
+ _LOGGER.debug("*** tile_shape: " + str(operation.tile_description.tile_shape))
935
+ _LOGGER.debug("*** warp_count: " + str(operation.tile_description.warp_count))
936
+
937
+ opcode_class_main = operation.tile_description.math_instruction.opcode_class
938
+ opcode_class_epi = opcode_class_main
939
+
940
+ tile_shape = operation.tile_description.tile_shape
941
+ instruction_shape = operation.tile_description.math_instruction.instruction_shape
942
+ cluster_m = operation.tile_description.cluster_shape[0]
943
+ cluster_n = operation.tile_description.cluster_shape[1]
944
+ cta_n = tile_shape[1] // cluster_n if cluster_n > 0 else tile_shape[1]
945
+ tile_shape_m, tile_shape_n, tile_shape_k = operation.get_collective_tile_shape()
946
+
947
+ # stage count set to zero indicates builder automatic stage selection
948
+ if operation.tile_description.stages > 0:
949
+ stage_count_string = f"cutlass::gemm::collective::StageCount<{str(operation.tile_description.stages)}>"
950
+ elif opcode_class_main == OpcodeClass.SparseTensorOp and operation.arch == 100:
951
+ stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveoutEpi<{str(operation.procedural_name())}_epilogue>"
952
+ else:
953
+ stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>"
954
+
955
+ epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
956
+
957
+ instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \
958
+ (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout)
959
+
960
+ # 3.0 profiler integration only supports trivial epilogues for now
961
+ epilogue_vector_length = 1
962
+
963
+ # Support built-in epilogue functors or user-defined functions
964
+ if isinstance(operation.epilogue_functor, enum.Enum):
965
+ values = {
966
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
967
+ 'epilogue_functor': EpilogueFunctor3xTag[operation.epilogue_functor],
968
+ }
969
+ epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
970
+
971
+ if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
972
+ epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
973
+
974
+
975
+ else:
976
+ epilogue_functor = self.epilogue_functor.emit_declaration()
977
+
978
+ if is_block_scaled(operation.gemm_kind) and operation.ScaleFactorD.element != DataType.void:
979
+ epilogue_functor = self.emit_block_scale_epilogue_functor(operation)
980
+
981
+ #
982
+ # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple<Element, Transform>, Transform : cute::identity / cute::conjugate.
983
+ element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>"
984
+ element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>"
985
+ epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule]
986
+
987
+ if opcode_class_main == OpcodeClass.BlockScaledTensorOp:
988
+ grouped = is_grouped(operation.gemm_kind)
989
+ if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100, grouped):
990
+ epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
991
+ if is_tma_epilogue(operation.epilogue_schedule):
992
+ epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
993
+ if cta_n == 256 and operation.kernel_schedule == to_grouped_schedule(KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100, grouped):
994
+ epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
995
+ if is_tma_epilogue(operation.epilogue_schedule):
996
+ epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
997
+ # SM103 FP4 Ultra
998
+ is_sm103_fp4_ultra_1sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103, grouped),
999
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103, grouped),
1000
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch, grouped),
1001
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch, grouped),
1002
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch, grouped),
1003
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch, grouped)
1004
+ ]
1005
+ is_sm103_fp4_ultra_2sm_kernel_schedule = operation.kernel_schedule in [to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103, grouped),
1006
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103, grouped),
1007
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch, grouped),
1008
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch, grouped),
1009
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch, grouped),
1010
+ to_grouped_schedule(KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch, grouped)
1011
+ ]
1012
+ if cta_n == 256 and is_sm103_fp4_ultra_1sm_kernel_schedule:
1013
+ epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
1014
+ if is_tma_epilogue(operation.epilogue_schedule):
1015
+ epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized1Sm, grouped)]
1016
+ if cta_n == 256 and is_sm103_fp4_ultra_2sm_kernel_schedule:
1017
+ epi_tile_mn = "cute::Shape<cute::_128,cute::_64>"
1018
+ if is_tma_epilogue(operation.epilogue_schedule):
1019
+ epilogue_schedule_type = EpilogueScheduleTag[to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized2Sm, grouped)]
1020
+
1021
+ element_a = f'cute::tuple<{str(element_a)},{str(DataTypeTag[operation.ScaleFactorA])}>'
1022
+ element_b = f'cute::tuple<{str(element_b)},{str(DataTypeTag[operation.ScaleFactorB])}>'
1023
+
1024
+ alignment_c = get_tma_alignment(operation.C.element) \
1025
+ if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \
1026
+ else operation.C.alignment
1027
+ alignment_d = get_tma_alignment(operation.D.element) \
1028
+ if is_tma_epilogue(operation.epilogue_schedule) and opcode_class_epi != OpcodeClass.Simt \
1029
+ else operation.D.alignment
1030
+
1031
+ operation_name_str = operation.procedural_name()
1032
+ layout_a_str = LayoutTag[instance_layout_A]
1033
+ layout_b_str = LayoutTag[instance_layout_B]
1034
+ mixed_dtype_prepare_code = ""
1035
+ if operation.mixed_input_mode != None:
1036
+ A_dtype = operation.A.element
1037
+ B_dtype = operation.B.element
1038
+ A_dtype_bits = DataTypeSize[A_dtype]
1039
+ B_dtype_bits = DataTypeSize[B_dtype]
1040
+ is_A_dtype_narrow = A_dtype_bits < B_dtype_bits
1041
+ if is_A_dtype_narrow:
1042
+ narrow_dtype, wide_dtype = (A_dtype, B_dtype)
1043
+ narrow_dtype_bits, wide_dtype_bits = (A_dtype_bits, B_dtype_bits)
1044
+ else:
1045
+ narrow_dtype, wide_dtype = (B_dtype, A_dtype)
1046
+ narrow_dtype_bits, wide_dtype_bits = (B_dtype_bits, A_dtype_bits)
1047
+
1048
+ narrow_tag = DataTypeTag[narrow_dtype]
1049
+ wide_tag = DataTypeTag[wide_dtype]
1050
+ scale_tag = DataTypeTag[wide_dtype]
1051
+ zero_tag = DataTypeTag[wide_dtype]
1052
+
1053
+ do_shuffle = False
1054
+ value_shuffle_str = ""
1055
+ if narrow_dtype_bits == 4 and wide_dtype_bits == 16:
1056
+ value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_4>, cute::Stride<cute::_4,cute::_1>>"
1057
+ do_shuffle = True
1058
+ if narrow_dtype_bits == 8 and wide_dtype_bits == 16:
1059
+ value_shuffle_str = "cute::Layout<cute::Shape<cute::_2,cute::_2>, cute::Stride<cute::_2,cute::_1>>"
1060
+ do_shuffle = True
1061
+ do_shuffle = operation.mixed_input_shuffle and do_shuffle
1062
+
1063
+ if do_shuffle:
1064
+ if is_A_dtype_narrow:
1065
+ stride_narrow_str = f"cutlass::detail::TagToStrideA_t<{layout_a_str}>"
1066
+ layout_a_str = f"{operation_name_str}_LayoutNarrowReordered"
1067
+ else:
1068
+ stride_narrow_str = f"cutlass::detail::TagToStrideB_t<{layout_b_str}>"
1069
+ layout_b_str = f"{operation_name_str}_LayoutNarrowReordered"
1070
+ # The {operation_name_str}_ prefixs in mixed_dtype_prepare_code and
1071
+ # layout_{a, b}_str are to prevent errors in Windows platform unity build
1072
+ mixed_dtype_prepare_code = f"""
1073
+ using {operation_name_str}_StrideNarrow = {stride_narrow_str};
1074
+ using {operation_name_str}_ValueShuffle = {value_shuffle_str};
1075
+ static constexpr int {operation_name_str}_NumShuffleAtoms = 1;
1076
+ using {operation_name_str}_MmaAtomShape = cute::Layout<cute::Shape<cute::_1, cute::Int<{operation_name_str}_NumShuffleAtoms>>>;
1077
+ using {operation_name_str}_LayoutAtomQuant = decltype(cutlass::compute_memory_reordering_atom<{wide_tag}, {operation_name_str}_MmaAtomShape, {operation_name_str}_ValueShuffle>());
1078
+ using {operation_name_str}_LayoutNarrowReordered = decltype(cute::tile_to_shape({operation_name_str}_LayoutAtomQuant{{}}, cute::Layout<cute::Shape<int,int,int>, {operation_name_str}_StrideNarrow>{{}}));
1079
+ """
1080
+
1081
+ mixed_input_modes_to_element = {
1082
+ MixedInputMode.ConvertOnly: narrow_tag,
1083
+ MixedInputMode.ScaleOnly: f"cute::tuple<{narrow_tag}, {scale_tag}>",
1084
+ MixedInputMode.ScaleWithZeroPoint: f"cute::tuple<{narrow_tag}, {scale_tag}, {zero_tag}>"
1085
+ }
1086
+ narrow_element = mixed_input_modes_to_element.get(operation.mixed_input_mode, narrow_tag)
1087
+
1088
+ if narrow_dtype == DataType.s4 and (wide_dtype == DataType.e4m3 or wide_dtype == DataType.e5m2):
1089
+ narrow_element = f"cute::tuple<{narrow_tag}, cutlass::Array<{scale_tag}, 8>>"
1090
+
1091
+ if is_A_dtype_narrow:
1092
+ element_a = narrow_element
1093
+ else:
1094
+ element_b = narrow_element
1095
+
1096
+ blockwise_prepare_code = ""
1097
+ if is_blockwise(operation.gemm_kind):
1098
+ sfm_vec_size = operation.ScaleFactorMVecSize
1099
+ sfn_vec_size = operation.ScaleFactorNVecSize
1100
+ sfk_vec_size = operation.ScaleFactorKVecSize
1101
+ blockwise_prepare_code = f"""
1102
+ using {operation_name_str}_ScaleConfig = cutlass::detail::Sm{operation.arch}BlockwiseScaleConfig<{sfm_vec_size}, {sfn_vec_size}, {sfk_vec_size}>;
1103
+ using {operation_name_str}_LayoutSFA = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFA());
1104
+ using {operation_name_str}_LayoutSFB = decltype({operation_name_str}_ScaleConfig::deduce_layoutSFB());
1105
+ """
1106
+
1107
+ values = {
1108
+ 'operation_name': operation_name_str,
1109
+ 'operation_suffix': self.operation_suffix,
1110
+ 'problem_shape': self.problem_shape(operation),
1111
+ 'element_a': element_a,
1112
+ 'layout_a': self.transform_layout_A_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_a_str)),
1113
+ 'element_b': element_b,
1114
+ 'layout_b': self.transform_layout_B_if_blockwise(operation, self.pointerize_if_grouped(operation, layout_b_str)),
1115
+ 'element_c': DataTypeTag[operation.C.element],
1116
+ 'layout_c': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_C]),
1117
+ 'element_d': DataTypeTag[operation.D.element],
1118
+ 'layout_d': self.pointerize_if_grouped(operation, LayoutTag[instance_layout_D]),
1119
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
1120
+ 'opcode_class_main': OpcodeClassTag[opcode_class_main],
1121
+ 'opcode_class_epi': OpcodeClassTag[opcode_class_epi],
1122
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
1123
+ 'tile_shape_m': str(tile_shape_m),
1124
+ 'tile_shape_n': str(tile_shape_n),
1125
+ 'tile_shape_k': str(tile_shape_k),
1126
+ 'cluster_shape_m': 'cute::_' + str(operation.tile_description.cluster_shape[0]) if operation.tile_description.cluster_shape[0] > 0 else "int",
1127
+ 'cluster_shape_n': 'cute::_' + str(operation.tile_description.cluster_shape[1]) if operation.tile_description.cluster_shape[1] > 0 else "int",
1128
+ 'cluster_shape_k': 'cute::_' + str(operation.tile_description.cluster_shape[2]) if operation.tile_description.cluster_shape[2] > 0 else "int",
1129
+ 'instruction_shape_m': str(instruction_shape[0]),
1130
+ 'instruction_shape_n': str(instruction_shape[1]),
1131
+ 'instruction_shape_k': str(instruction_shape[2]),
1132
+ 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]),
1133
+ 'epilogue_schedule' : str(epilogue_schedule_type),
1134
+ 'epi_tile_mn' : epi_tile_mn,
1135
+ 'epilogue_functor': epilogue_functor,
1136
+ 'stages': stage_count_string,
1137
+ 'align_a': str(operation.A.alignment),
1138
+ 'align_b': str(operation.B.alignment),
1139
+ 'align_c': str(alignment_c),
1140
+ 'align_d': str(alignment_d),
1141
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
1142
+ 'transform_b': ComplexTransformTag[operation.B.complex_transform],
1143
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
1144
+ 'epilogue_vector_length': str(epilogue_vector_length),
1145
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
1146
+ 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]),
1147
+ 'mixed_dtype_prepare_code': mixed_dtype_prepare_code,
1148
+ 'blockwise_prepare_code' : blockwise_prepare_code
1149
+ }
1150
+
1151
+ return SubstituteTemplate(self.gemm_template, values)
1152
+
1153
+ ###################################################################################################
1154
+
1155
+ #
1156
+ class EmitGemmPlanarComplexInstance:
1157
+ ''' Responsible for emitting a CUTLASS template definition'''
1158
+
1159
+ def __init__(self, operation_suffix = ''):
1160
+ self.operation_suffix = operation_suffix
1161
+ self.includes = []
1162
+ self.template = """
1163
+ // Gemm operator ${operation_name}
1164
+ using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
1165
+ ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
1166
+ ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
1167
+ ${element_c}, cutlass::layout::RowMajor,
1168
+ ${element_accumulator},
1169
+ ${opcode_class},
1170
+ ${arch},
1171
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1172
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1173
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1174
+ cutlass::epilogue::thread::LinearCombinationPlanarComplex<
1175
+ ${element_c},
1176
+ ${alignment_c},
1177
+ ${element_accumulator},
1178
+ ${element_epilogue}
1179
+ >,
1180
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
1181
+ ${stages},
1182
+ ${math_operator}
1183
+ >::GemmKernel;
1184
+
1185
+ struct ${operation_name} :
1186
+ public Operation_${operation_name} { };
1187
+ """
1188
+
1189
+ #
1190
+ def instance_template(self):
1191
+ return """
1192
+ ${compile_guard_start}
1193
+ manifest.append(new ${gemm_kind}<
1194
+ cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
1195
+ >("${operation_name}"));
1196
+ ${compile_guard_end}
1197
+ """
1198
+
1199
+ #
1200
+ def emit(self, operation):
1201
+
1202
+ warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
1203
+
1204
+ # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
1205
+ transposed_layout_A = TransposedLayout[operation.A.layout]
1206
+ transposed_layout_B = TransposedLayout[operation.B.layout]
1207
+
1208
+ values = {
1209
+ 'operation_name': operation.procedural_name(),
1210
+ 'element_a': DataTypeTag[operation.B.element],
1211
+ 'layout_a': LayoutTag[transposed_layout_B],
1212
+ 'transform_a': ComplexTransformTag[operation.B.complex_transform],
1213
+ 'alignment_a': str(operation.B.alignment),
1214
+ 'element_b': DataTypeTag[operation.A.element],
1215
+ 'layout_b': LayoutTag[transposed_layout_A],
1216
+ 'transform_b': ComplexTransformTag[operation.A.complex_transform],
1217
+ 'alignment_b': str(operation.A.alignment),
1218
+ 'element_c': DataTypeTag[operation.C.element],
1219
+ 'layout_c': LayoutTag[operation.C.layout],
1220
+ 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
1221
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
1222
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
1223
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
1224
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
1225
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
1226
+ 'warp_shape_m': str(warp_shape[0]),
1227
+ 'warp_shape_n': str(warp_shape[1]),
1228
+ 'warp_shape_k': str(warp_shape[2]),
1229
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
1230
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
1231
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
1232
+ 'alignment_c': str(operation.C.alignment),
1233
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
1234
+ 'stages': str(operation.tile_description.stages),
1235
+ 'math_operator': 'cutlass::arch::OpMultiplyAdd'
1236
+ }
1237
+
1238
+ return SubstituteTemplate(self.template, values)
1239
+
1240
+ ###################################################################################################
1241
+
1242
+ #
1243
+ class EmitGemmPlanarComplexArrayInstance:
1244
+ ''' Responsible for emitting a CUTLASS template definition'''
1245
+
1246
+ def __init__(self, operation_suffix = ''):
1247
+ self.operation_suffix = operation_suffix
1248
+ self.includes = []
1249
+ self.template = """
1250
+ // Gemm operator ${operation_name}
1251
+ using Operation_${operation_name} = typename cutlass::gemm::kernel::DefaultGemmPlanarComplexUniversal<
1252
+ ${element_a}, ${layout_a}, ${transform_a}, ${alignment_a},
1253
+ ${element_b}, ${layout_b}, ${transform_b}, ${alignment_b},
1254
+ ${element_c}, cutlass::layout::RowMajor,
1255
+ ${element_accumulator},
1256
+ ${opcode_class},
1257
+ ${arch},
1258
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1259
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1260
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1261
+ cutlass::epilogue::thread::LinearCombinationPlanarComplex<
1262
+ ${element_c},
1263
+ ${alignment_c},
1264
+ ${element_accumulator},
1265
+ ${element_epilogue}
1266
+ >,
1267
+ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
1268
+ ${stages},
1269
+ ${math_operator}
1270
+ >::GemmArrayKernel;
1271
+
1272
+ struct ${operation_name} : public Operation_${operation_name} { };
1273
+ """
1274
+
1275
+ #
1276
+ def instance_template(self):
1277
+ return """
1278
+ ${compile_guard_start}
1279
+ manifest.append(new ${gemm_kind}<
1280
+ cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
1281
+ >("${operation_name}"));
1282
+ ${compile_guard_end}
1283
+ """
1284
+
1285
+ #
1286
+ def emit(self, operation):
1287
+
1288
+ warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)]
1289
+
1290
+ # exchange and transpose A and B types, layouts, and complex transforms since the C layout is row-major
1291
+ transposed_layout_A = TransposedLayout[operation.A.layout]
1292
+ transposed_layout_B = TransposedLayout[operation.B.layout]
1293
+
1294
+ values = {
1295
+ 'operation_name': operation.procedural_name(),
1296
+ 'element_a': DataTypeTag[operation.B.element],
1297
+ 'layout_a': LayoutTag[transposed_layout_B],
1298
+ 'transform_a': ComplexTransformTag[operation.B.complex_transform],
1299
+ 'alignment_a': str(operation.B.alignment),
1300
+ 'element_b': DataTypeTag[operation.A.element],
1301
+ 'layout_b': LayoutTag[transposed_layout_A],
1302
+ 'transform_b': ComplexTransformTag[operation.A.complex_transform],
1303
+ 'alignment_b': str(operation.A.alignment),
1304
+ 'element_c': DataTypeTag[operation.C.element],
1305
+ 'layout_c': LayoutTag[operation.C.layout],
1306
+ 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
1307
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
1308
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
1309
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
1310
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
1311
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
1312
+ 'warp_shape_m': str(warp_shape[0]),
1313
+ 'warp_shape_n': str(warp_shape[1]),
1314
+ 'warp_shape_k': str(warp_shape[2]),
1315
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
1316
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
1317
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
1318
+ 'alignment_c': str(operation.C.alignment),
1319
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
1320
+ 'stages': str(operation.tile_description.stages),
1321
+ 'math_operator': 'cutlass::arch::OpMultiplyAdd'
1322
+ }
1323
+
1324
+ return SubstituteTemplate(self.template, values)
1325
+
1326
+ ###################################################################################################
1327
+
1328
+ #
1329
+ class EmitGemmGroupedInstance:
1330
+ ''' Responsible for emitting a CUTLASS template definition'''
1331
+
1332
+ def __init__(self, operation_suffix = ''):
1333
+ self.operation_suffix = operation_suffix
1334
+ self.includes = [
1335
+ "cutlass/cutlass.h",
1336
+ "cutlass/numeric_types.h",
1337
+ "cutlass/arch/arch.h",
1338
+ "cutlass/arch/mma.h",
1339
+ "cutlass/layout/matrix.h",
1340
+ "cutlass/gemm/device/gemm.h",
1341
+ "cutlass/gemm/kernel/gemm_grouped.h",
1342
+ "cutlass/gemm/kernel/default_gemm_grouped.h",
1343
+ "cutlass/gemm/device/gemm_grouped.h"
1344
+ ]
1345
+ self.builtin_epilogue_functor_template = \
1346
+ """${epilogue_functor}<
1347
+ ${element_c},
1348
+ ${epilogue_vector_length},
1349
+ ${element_accumulator},
1350
+ ${element_epilogue}
1351
+ >"""
1352
+
1353
+ self.gemm_template = """
1354
+ // Gemm operator ${operation_name}
1355
+ using ${operation_name}_base =
1356
+ typename cutlass::gemm::kernel::DefaultGemmGrouped<
1357
+ ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
1358
+ ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
1359
+ ${element_c}, ${layout_c},
1360
+ ${element_accumulator},
1361
+ ${opcode_class},
1362
+ ${arch},
1363
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1364
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1365
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1366
+ ${epilogue_functor},
1367
+ ${swizzling_functor},
1368
+ ${stages},
1369
+ ${scheduler_mode},
1370
+ ${math_operation}
1371
+ >::GemmKernel;
1372
+
1373
+ // Define named type
1374
+ struct ${operation_name}${operation_suffix} :
1375
+ public ${operation_name}_base { };
1376
+ """
1377
+
1378
+ #
1379
+ def instance_template(self):
1380
+ return """
1381
+ ${compile_guard_start}
1382
+ manifest.append(new ${gemm_kind}<
1383
+ cutlass::gemm::device::GemmGrouped<${operation_name}>
1384
+ >("${operation_name}"));
1385
+ ${compile_guard_end}
1386
+ """
1387
+
1388
+ #
1389
+ def emit(self, operation):
1390
+
1391
+ threadblock_shape = operation.tile_description.threadblock_shape
1392
+ warp_count = operation.tile_description.warp_count
1393
+
1394
+ warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
1395
+
1396
+ transpose_layouts = {
1397
+ LayoutType.ColumnMajor: LayoutType.RowMajor,
1398
+ LayoutType.RowMajor: LayoutType.ColumnMajor
1399
+ }
1400
+
1401
+ instance_layout_A, instance_layout_B, instance_layout_C = \
1402
+ (operation.A.layout, operation.B.layout, operation.C.layout)
1403
+ #
1404
+
1405
+ # Support built-in epilogue functors or user-defined functions
1406
+ if isinstance(operation.epilogue_functor, enum.Enum):
1407
+
1408
+ epilogue_vector_length = \
1409
+ min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element]
1410
+
1411
+ values = {
1412
+ 'epilogue_vector_length': str(epilogue_vector_length),
1413
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
1414
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
1415
+ }
1416
+ epilogue_functor = SubstituteTemplate(self.builtin_epilogue_functor_template, values)
1417
+ else:
1418
+ epilogue_functor = self.epilogue_functor.emit_declaration()
1419
+ #
1420
+
1421
+ values = {
1422
+ 'operation_name': operation.procedural_name(),
1423
+ 'operation_suffix': self.operation_suffix,
1424
+ 'element_a': DataTypeTag[operation.A.element],
1425
+ 'layout_a': LayoutTag[instance_layout_A],
1426
+ 'element_b': DataTypeTag[operation.B.element],
1427
+ 'layout_b': LayoutTag[instance_layout_B],
1428
+ 'element_c': DataTypeTag[operation.C.element],
1429
+ 'layout_c': LayoutTag[instance_layout_C],
1430
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
1431
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
1432
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
1433
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
1434
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
1435
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
1436
+ 'warp_shape_m': str(warp_shape[0]),
1437
+ 'warp_shape_n': str(warp_shape[1]),
1438
+ 'warp_shape_k': str(warp_shape[2]),
1439
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
1440
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
1441
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
1442
+ 'epilogue_functor': epilogue_functor,
1443
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
1444
+ 'stages': str(operation.tile_description.stages),
1445
+ 'align_a': str(operation.A.alignment),
1446
+ 'align_b': str(operation.B.alignment),
1447
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
1448
+ 'transform_b': ComplexTransformTag[operation.B.complex_transform],
1449
+ 'scheduler_mode': GroupScheduleModeTag[operation.scheduler_mode],
1450
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation]
1451
+ }
1452
+
1453
+ return SubstituteTemplate(self.gemm_template, values)
1454
+
1455
+ ###################################################################################################
1456
+ #
1457
+ # Emitters functions for all targets
1458
+ #
1459
+ ###################################################################################################
1460
+
1461
+ class EmitGemmConfigurationLibrary:
1462
+ def __init__(self, operation_path, configuration_name):
1463
+ self.configuration_name = configuration_name
1464
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
1465
+
1466
+ self.instance_emitter = {
1467
+ GemmKind.Gemm: EmitGemmInstance,
1468
+ GemmKind.Sparse: EmitSparseGemmInstance,
1469
+ GemmKind.Universal: EmitGemmUniversalInstance,
1470
+ GemmKind.Universal3x: EmitGemmUniversal3xInstance,
1471
+ GemmKind.SparseUniversal3x: EmitGemmUniversal3xInstance,
1472
+ GemmKind.BlockScaledUniversal3x: EmitGemmUniversal3xInstance,
1473
+ GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance,
1474
+ GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance,
1475
+ GemmKind.Grouped: EmitGemmGroupedInstance,
1476
+ GemmKind.GroupedUniversal3x: EmitGemmUniversal3xInstance,
1477
+ GemmKind.GroupedBlockScaledUniversal3x: EmitGemmUniversal3xInstance,
1478
+ GemmKind.BlockwiseUniversal3x: EmitGemmUniversal3xInstance,
1479
+ GemmKind.GroupedBlockwiseUniversal3x: EmitGemmUniversal3xInstance,
1480
+ }
1481
+
1482
+ self.gemm_kind_wrappers = {
1483
+ GemmKind.Gemm: 'GemmOperation',
1484
+ GemmKind.Sparse: 'GemmSparseOperation',
1485
+ GemmKind.Universal: 'GemmUniversalOperation',
1486
+ GemmKind.Universal3x: 'GemmUniversal3xOperation',
1487
+ GemmKind.SparseUniversal3x: 'SparseGemmUniversal3xOperation',
1488
+ GemmKind.BlockScaledUniversal3x: 'BlockScaledGemmUniversal3xOperation',
1489
+ GemmKind.PlanarComplex: 'GemmPlanarComplexOperation',
1490
+ GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation',
1491
+ GemmKind.Grouped: 'GemmGroupedOperation',
1492
+ GemmKind.GroupedUniversal3x: 'GroupedGemmUniversal3xOperation',
1493
+ GemmKind.GroupedBlockScaledUniversal3x: 'GroupedBlockScaledGemmUniversal3xOperation',
1494
+ GemmKind.BlockwiseUniversal3x: 'BlockwiseGemmUniversal3xOperation',
1495
+ GemmKind.GroupedBlockwiseUniversal3x: 'GroupedBlockwiseGemmUniversal3xOperation',
1496
+ }
1497
+
1498
+ self.wmma_guard_start = "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
1499
+
1500
+ self.separator = """
1501
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1502
+
1503
+ """
1504
+
1505
+ self.header_template = """
1506
+ /*
1507
+ Generated by gemm_operation.py - Do not edit.
1508
+ */
1509
+ """
1510
+
1511
+ self.initialize_function_template = """
1512
+
1513
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1514
+
1515
+ namespace cutlass {
1516
+ namespace library {
1517
+
1518
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1519
+
1520
+ void initialize_${configuration_name}(Manifest &manifest) {
1521
+
1522
+ """
1523
+ self.epilogue_template = """
1524
+
1525
+ }
1526
+
1527
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1528
+
1529
+ } // namespace library
1530
+ } // namespace cutlass
1531
+
1532
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
1533
+
1534
+ """
1535
+
1536
+ def __enter__(self):
1537
+ _LOGGER.debug("*** EmitGemmConfigurationLibrary::__enter__")
1538
+ _LOGGER.debug("*** configuration_path (file to write): " +
1539
+ str(self.configuration_path))
1540
+
1541
+ self.configuration_file = open(self.configuration_path, "w")
1542
+ self.configuration_file.write(self.header_template)
1543
+ self.configuration_file.write(self.separator)
1544
+
1545
+ self.includes = collections.OrderedDict([
1546
+ ("cutlass/cutlass.h", None),
1547
+ ("cutlass/library/library.h", None),
1548
+ ("cutlass/library/manifest.h", None),
1549
+ ("library_internal.h", None),
1550
+ ("gemm_operation.h", None),
1551
+ ("gemm_operation_3x.hpp", None),
1552
+ ("grouped_gemm_operation_3x.hpp", None),
1553
+ ("sparse_gemm_operation_3x.hpp", None),
1554
+ ("block_scaled_gemm_operation_3x.hpp", None),
1555
+ ("blockwise_gemm_operation_3x.hpp", None),
1556
+ ("cutlass/arch/wmma.h", None),
1557
+ ("cutlass/numeric_types.h", None)
1558
+ ])
1559
+ self.instance_definitions = []
1560
+ self.instance_wrappers = []
1561
+
1562
+ self.operations = []
1563
+ return self
1564
+
1565
+ def emit(self, operation):
1566
+ _LOGGER.debug("*** EmitGemmConfigurationLibrary::emit(operation)")
1567
+ _LOGGER.debug("*** operation.gemm_kind: " + str(operation.gemm_kind))
1568
+
1569
+ emitter = self.instance_emitter[operation.gemm_kind]()
1570
+
1571
+ for incl in emitter.includes:
1572
+ self.includes[incl] = None
1573
+
1574
+ self.operations.append(operation)
1575
+
1576
+ self.instance_definitions.append(emitter.emit(operation))
1577
+
1578
+ self.instance_wrappers.append(SubstituteTemplate(emitter.instance_template(), {
1579
+ 'configuration_name': self.configuration_name,
1580
+ 'operation_name': operation.procedural_name(),
1581
+ 'gemm_kind': self.gemm_kind_wrappers[operation.gemm_kind],
1582
+ 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
1583
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
1584
+ 'compile_guard_end': "#endif" \
1585
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
1586
+ }))
1587
+
1588
+ def __exit__(self, exception_type, exception_value, traceback):
1589
+
1590
+ # Write includes
1591
+ for incl, _ in self.includes.items():
1592
+ include_statement = "#include \"%s\"\n" % incl
1593
+ self.configuration_file.write(include_statement)
1594
+
1595
+ self.configuration_file.write(self.separator)
1596
+
1597
+ # Write instance definitions in top-level namespace
1598
+ for instance_definition in self.instance_definitions:
1599
+ self.configuration_file.write(instance_definition)
1600
+
1601
+ # Add wrapper objects within initialize() function
1602
+ self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
1603
+ 'configuration_name': self.configuration_name
1604
+ }))
1605
+
1606
+ for instance_wrapper in self.instance_wrappers:
1607
+ self.configuration_file.write(instance_wrapper)
1608
+
1609
+ self.configuration_file.write(self.epilogue_template)
1610
+ self.configuration_file.close()
1611
+
1612
+ ###################################################################################################
1613
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/generator.py ADDED
The diff for this file is too large to render. See raw diff
 
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for selecting CUTLASS library kernels based on problem description
35
+ """
36
+ import json
37
+ import csv
38
+
39
+ try:
40
+ import builtins
41
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
42
+ raise ImportError("Disabling attempt to import cutlass_library")
43
+ from cutlass_library.library import *
44
+ from cutlass_library.generator import *
45
+ from cutlass_library.heuristics_provider import *
46
+ except ImportError:
47
+ from library import *
48
+ from generator import *
49
+ from heuristics_provider import *
50
+
51
+ try:
52
+ from .sm90_utils import (
53
+ get_valid_schedules,
54
+ generate_data_types_from_math_instruction,
55
+ fix_alignments,
56
+ )
57
+ except ImportError:
58
+ from sm90_utils import (
59
+ get_valid_schedules,
60
+ generate_data_types_from_math_instruction,
61
+ fix_alignments,
62
+ )
63
+
64
+ _LOGGER = logging.getLogger(__name__)
65
+
66
+ dtype_map = {v: k for k, v in DataTypeNames.items()}
67
+
68
+ def serialize_heuristics_results_to_json(problems_with_configs, outfile_path):
69
+ """
70
+ Utilitiy function to write heuristics results to a json file for debug
71
+
72
+ args:
73
+ problems_with_configs: List of problems provided to the heuristic, with a list of operations added to each problem dict
74
+ outfile_path: Outfile path
75
+
76
+ returns:
77
+ None
78
+ """
79
+ pc_copy = problems_with_configs.copy()
80
+ for p in pc_copy:
81
+ for k, v in p.items():
82
+ if isinstance(v, DataType):
83
+ p[k] = DataTypeNames[v]
84
+ elif isinstance(v, LayoutType):
85
+ p[k] = ShortLayoutTypeNames[v]
86
+ configs = p['configs']
87
+ for c in configs:
88
+ for k, v in c.items():
89
+ if isinstance(v, DataType):
90
+ c[k] = DataTypeNames[v]
91
+ elif isinstance(v, LayoutType):
92
+ c[k] = ShortLayoutTypeNames[v]
93
+ with open(outfile_path, 'w') as f:
94
+ json.dump(pc_copy, f, indent=2)
95
+
96
+ def get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, voidC=False, use_fast_acc=True, count=1, provider=None):
97
+ """
98
+ Get heuristic-suggested GEMM kernel configurations for a single GEMM problem.
99
+
100
+ args:
101
+ m, n, k: GEMM dimensions
102
+ batch_count: batch count
103
+ layouts: tuple of layouts of type LayoutType
104
+ use_fast_acc: Use fast accumulation for FP8. Ignored for other precisions
105
+ count: Number of configs to return
106
+ provider: Heuristics provider to use
107
+
108
+ returns:
109
+ A list of dictionaries containing the suggested kernel configurations and additional info from the input required to define a Cutlass GemmOperation, with the following keys:
110
+ - 'cta_tile_m', 'cta_tile_m', 'cta_tile_k': CTA tile size
111
+ - 'instr_tile_m', 'instr_tile_n', 'instr_tile_k': Instruction tile size
112
+ - 'stages': kernel pipeline stage count
113
+ - 'cluster_m', 'cluster_n', 'cluster_k': cluster size
114
+ - 'layout_a', 'layout_b': input tensor layouts of type LayoutType
115
+ - 'alignment_a', 'alignment_b': input tensor alignments, in count of elements
116
+ - 'dtype_a', 'dtype_b', 'dtype_acc': dtypes of a, b, and accumulator, of type DataType
117
+ - 'swizzle_size' : suggested threadblock swizzle
118
+ - 'split_k_slices': number of partitions of the k dimension for splitK
119
+ - 'raster_order': raster order for CTAs over output tiles ('along_m' or 'along_n')
120
+ """
121
+ if provider is None:
122
+ provider = MatmulHeuristics()
123
+ return provider.get_configs(m, n, k, batch_count, dtypes, layouts, alignment_a, alignment_b, voidC=voidC, use_fast_acc=use_fast_acc, count=count)
124
+
125
+ def get_gemm_configs(problems, provider=None, count=1):
126
+ """
127
+ Get heuristic-suggested GEMM kernel configurations for a set of GEMM problems.
128
+
129
+ args:
130
+ problems: List of dictionaries describing GEMM problems with the following keys:
131
+ - 'm', 'n', 'k': Matrix dimensions (required)
132
+ - 'dtype_a': Data type of matrix A (required)
133
+ - 'dtype_b': Data type of matrix B (required)
134
+ - 'dtype_c': Data type of matrix C (default: None)
135
+ - 'dtype_d': Data type of matrix D (required)
136
+ - 'dtype_acc': Compute data type (default 'f32')
137
+ - 'layout': Operation layout (e.g. 'tnt')
138
+ - 'alignment_a': Memory access granularity of A, in units of elements (default: 16 bytes equivalent elements)
139
+ - 'alignment_b': Memory access granularity of B, in units of elements (default: 16 bytes equivalent elements)
140
+ - 'alpha': Scalar multiplier for A*B (default: 1.0)
141
+ - 'beta': Scalar multiplier for C (default: 0.0)
142
+ - 'batch_count': Number of GEMM operations in batch (default: 1)
143
+ - 'use_fast_acc': Enable fast accumulation for FP8 on Hopper (default: True)
144
+ provider: Heuristics provider to use
145
+ count: Number of configurations to return per problem (defualt: 1)
146
+
147
+ returns:
148
+ A copy of the input dictionary, with key `configs` added containing the selected gemm configs
149
+ """
150
+ ret = []
151
+
152
+ for problem in problems:
153
+ problem = problem.copy()
154
+
155
+ try:
156
+ m = problem['m']
157
+ n = problem['n']
158
+ k = problem['k']
159
+ dtype_a = problem['dtype_a']
160
+ dtype_b = problem['dtype_b']
161
+ dtype_d = problem['dtype_d']
162
+ layout = problem['layout']
163
+ except KeyError as e:
164
+ _LOGGER.error(f"Missing required parameter {e} for problem {problem}")
165
+ raise
166
+
167
+ operation = problem.get('operation', 'gemm')
168
+ batch_count = problem.get('batch_count', 1)
169
+ dtype_acc = problem.get('dtype_acc', 'f32')
170
+ dtype_c = problem.get('dtype_c', None)
171
+ alpha = problem.get('alpha', 1.0)
172
+ beta = problem.get('beta', 0.0)
173
+ use_fast_acc = problem.get('use_fast_acc', True)
174
+
175
+ if operation != OperationKindNames[OperationKind.Gemm]:
176
+ raise ValueError(f"Unsupported operation {operation}")
177
+ if not (len(layout) == 3 and all(c in "nt" for c in layout)):
178
+ raise ValueError(f"layout must be a 3-character string containing only 'n' or 't', got {layout}")
179
+ layouts = tuple(LayoutType.RowMajor if l == 't' else LayoutType.ColumnMajor for l in layout)
180
+
181
+ try:
182
+ dtype_list = [dtype_a.lower(), dtype_b.lower(), dtype_acc.lower(), dtype_c.lower() if dtype_c is not None else dtype_d.lower(), dtype_d.lower()]
183
+ dtypes = tuple(dtype_map[dt] for dt in dtype_list)
184
+ except KeyError as dt:
185
+ _LOGGER.error(f"Unsupported data type: {dt}")
186
+ raise
187
+
188
+ alignment_a = problem.get('alignment_a', 128 // DataTypeSize[dtypes[0]])
189
+ alignment_b = problem.get('alignment_b', 128 // DataTypeSize[dtypes[1]])
190
+
191
+ configs = get_single_gemm_config(m, n, k, batch_count, layouts, dtypes, alignment_a, alignment_b, beta==0.0, use_fast_acc, count, provider)
192
+ problem['configs'] = configs
193
+
194
+ ret.append(problem)
195
+
196
+ return ret
197
+
198
+
199
+ def generate_sm100_from_heuristics_configs(manifest, cuda_version, kernel_configs):
200
+ """
201
+ Generate CUTLASS operations based on the list of configs provided by the heuristic provider
202
+
203
+ args:
204
+ manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
205
+ cuda_version: Cuda compiler version for generating cutlass operations
206
+ kernel_configs: list of configs generated by the heuristic
207
+
208
+ returns:
209
+ (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
210
+ """
211
+ min_cc = 100
212
+ max_cc = 101
213
+ if manifest is None:
214
+ # Use a dummy manifest so we can use existing CreateGemmOperator functions
215
+ manifest = Manifest()
216
+
217
+ configs = []
218
+ operations = []
219
+ for config in kernel_configs:
220
+ layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [config['layout_d'], 128 // DataTypeSize[config['dtype_d']]])
221
+ element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
222
+
223
+ # nvMMH assumes 2sm instruction for !(cluster_m % 2)
224
+ is_2sm = config['cluster_m'] % 2 == 0
225
+ instruction_shape = [(2 * config['cta_tile_m']) if is_2sm else config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k'] // 4]
226
+ math_instruction = MathInstruction(
227
+ instruction_shape,
228
+ element_a, element_b, element_accumulator,
229
+ OpcodeClass.TensorOp,
230
+ MathOperation.multiply_add
231
+ )
232
+
233
+ data_types = [
234
+ {
235
+ "a_type" : math_instruction.element_a,
236
+ "b_type" : math_instruction.element_b,
237
+ "c_type" : DataType.void if config['voidC'] else math_instruction.element_accumulator,
238
+ "d_type" : element_d,
239
+ "acc_type" : math_instruction.element_accumulator,
240
+ "epi_type" : math_instruction.element_accumulator,
241
+ }
242
+ ]
243
+
244
+ tile_multiplier = (config['cluster_m'] // (2 if is_2sm else 1), config['cluster_n'], config['cluster_k'])
245
+ tile_description = TileDescription(
246
+ [instruction_shape[0] * tile_multiplier[0],
247
+ instruction_shape[1] * tile_multiplier[1],
248
+ instruction_shape[2] * 4 * tile_multiplier[2]],
249
+ 0,
250
+ [4,1,1],
251
+ math_instruction,
252
+ min_cc,
253
+ max_cc,
254
+ cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
255
+ )
256
+
257
+ schedules = []
258
+ if is_2sm:
259
+ schedules.append([KernelScheduleType.TmaWarpSpecialized2SmSm100, EpilogueScheduleType.TmaWarpSpecialized2Sm])
260
+ else:
261
+ schedules.append([KernelScheduleType.TmaWarpSpecialized1SmSm100, EpilogueScheduleType.TmaWarpSpecialized1Sm])
262
+
263
+ for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, tile_schedulers=[TileSchedulerType.Default, TileSchedulerType.StreamK], gemm_kind=GemmKind.Universal3x):
264
+ configs.append(config)
265
+ operations.append(o)
266
+
267
+
268
+ return configs, operations
269
+
270
+
271
+ def generate_sm90_from_heuristics_configs(manifest, cuda_version, kernel_configs):
272
+ """
273
+ Generate CUTLASS operations based on the list of configs provided by the heuristic provider
274
+
275
+ args:
276
+ manifest: manifest argument to which to add operations, or None to just return the operations without a manifest (for pruning an existing manifest)
277
+ cuda_version: Cuda compiler version for generating cutlass operations
278
+ kernel_configs: list of configs generated by the heuristic
279
+
280
+ returns:
281
+ (configs, operations): a list of heuristic-provided kernel configs along with a one-to-one corresponding list of the generated operations
282
+ """
283
+ min_cc, max_cc = 90, 90
284
+
285
+ if manifest is None:
286
+ # Use a dummy manifest so we can use existing CreateGemmOperator functions
287
+ manifest = Manifest()
288
+
289
+ configs = []
290
+ operations = []
291
+ for config in kernel_configs:
292
+
293
+ is_aligned = (config['alignment_a'] * DataTypeSize[config['dtype_a']] >= 128) and (config['alignment_b'] * DataTypeSize[config['dtype_b']] >= 128)
294
+ layout = ([config['layout_a'], config['alignment_a']], [config['layout_b'], config['alignment_b']], [LayoutType.ColumnMajor, 1])
295
+ element_a, element_b, element_accumulator, element_c, element_d = config['dtype_a'], config['dtype_b'], config['dtype_acc'], config['dtype_c'], config['dtype_d']
296
+
297
+ # instr shape and warp config are unused for emitting 3x collective builder code
298
+ dummy_instr_shape = [0, 0, 0]
299
+ math_instruction = MathInstruction(
300
+ dummy_instr_shape,
301
+ element_a, element_b, element_accumulator,
302
+ OpcodeClass.TensorOp,
303
+ MathOperation.multiply_add
304
+ )
305
+
306
+ data_types = generate_data_types_from_math_instruction(math_instruction, element_source=element_c, element_dest=element_d)
307
+ if is_aligned:
308
+ layout = fix_alignments(data_types, layout, alignment_bits=128)
309
+
310
+ # instr shape and warp config are unused for emitting 3x collective builder code
311
+ dummy_warp_count = [0, 0, 0]
312
+ tile_description = TileDescription(
313
+ [config['cta_tile_m'], config['cta_tile_n'], config['cta_tile_k']],
314
+ 0,
315
+ dummy_warp_count,
316
+ math_instruction,
317
+ min_cc,
318
+ max_cc,
319
+ cluster_shape=(config['cluster_m'], config['cluster_n'], config['cluster_k'])
320
+ )
321
+
322
+ schedules, stream_k_schedules = get_valid_schedules(
323
+ tile_description=tile_description,
324
+ cuda_version=cuda_version,
325
+ is_aligned=is_aligned,
326
+ data_types=data_types,
327
+ instantiation_level=9000, # don't prune schedules: we didn't get any schedule suggestion from the heuristic
328
+ layout=layout,
329
+ gemm_kind=GemmKind.Universal3x,
330
+ enable_fp8_fast_acc=config['use_fast_acc']
331
+ )
332
+
333
+ if len(schedules):
334
+ for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types, schedules, gemm_kind=GemmKind.Universal3x):
335
+ configs.append(config)
336
+ operations.append(o)
337
+
338
+ if len(stream_k_schedules):
339
+ for o in CreateGemmUniversal3xOperator(manifest, [layout], [tile_description], data_types,
340
+ stream_k_schedules,
341
+ tile_schedulers=[TileSchedulerType.StreamK]):
342
+ configs.append(config)
343
+ operations.append(o)
344
+
345
+
346
+ return configs, operations
347
+
348
+ def filter_manifest_and_write_heuristics_file(manifest, args):
349
+ """
350
+ Prune a manifest according to heuristics suggestions from the problems file
351
+
352
+ args:
353
+ manifest: Cutlass manifest to prune
354
+ args: generator.py args, requires:
355
+ - args.heuristics_problems_file
356
+ - args.heuristics_gpu
357
+ - args.heuristics_testlist_file
358
+
359
+ returns:
360
+ A list of dictionaries, each of which has information about an operation and a problem from the input problems
361
+ """
362
+ heuristics_problems = []
363
+ with open(args.heuristics_problems_file, 'r') as f:
364
+ heuristics_problems = json.load(f)
365
+ gpu = None if (args.heuristics_gpu == "auto" or args.heuristics_gpu == "") else args.heuristics_gpu
366
+ mmh = MatmulHeuristics(gpu=gpu)
367
+ if any(('100' in arch) for arch in args.architectures.split(';')):
368
+ mmh.set_cta_div_n(64)
369
+ problems_with_configs = get_gemm_configs(heuristics_problems, provider=mmh, count=args.heuristics_configs_per_problem)
370
+
371
+ all_configs_and_operations = []
372
+ operations = []
373
+ for problem in problems_with_configs:
374
+ if any('90' in arch for arch in args.architectures.split(';')):
375
+ problem_configs, problem_operations = generate_sm90_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
376
+ if any(('100' in arch) or ('101' in arch) for arch in args.architectures.split(';')):
377
+ problem_configs, problem_operations = generate_sm100_from_heuristics_configs(None if args.heuristics_restrict_kernels else manifest, args.cuda_version, problem['configs'])
378
+
379
+ operations += problem_operations
380
+ problem_without_configs = {k: v for k, v in problem.items() if k != 'configs'}
381
+ with_problem_size = [{'operation_name': o.procedural_name(), **problem_without_configs, **c} for c, o in zip(problem_configs, problem_operations)]
382
+ all_configs_and_operations += with_problem_size
383
+
384
+ for operation in operations:
385
+ manifest.add_kernel_filter(f"^{operation.procedural_name()}$")
386
+ if not all_configs_and_operations:
387
+ raise Exception("No valid configurations generated")
388
+ write_profiler_testlist_to_csv(all_configs_and_operations, args.heuristics_testlist_file)
389
+ return all_configs_and_operations
390
+
391
+ def write_profiler_testlist_to_csv(configs_list, outfile_path):
392
+ """
393
+ Write a list of configs to a testlist to be consumed by cutlass_profiler
394
+
395
+ args:
396
+ configs_list: List of kernel configs along with runtime arguments and any other columns to include in the CSV, expressed as a list of dictionaries
397
+ outfile_path: Outfile path
398
+
399
+ returns:
400
+ None
401
+ """
402
+ profiler_testlist = configs_list.copy()
403
+ for c in profiler_testlist:
404
+ for k, v in c.items():
405
+ if isinstance(v, DataType):
406
+ c[k] = DataTypeNames[v]
407
+ elif isinstance(v, LayoutType):
408
+ c[k] = ShortLayoutTypeNames[v]
409
+
410
+ with open(outfile_path, mode='w', newline='') as ofile:
411
+ k_names = profiler_testlist[0].keys()
412
+
413
+ writer = csv.DictWriter(ofile, fieldnames=k_names)
414
+ writer.writeheader()
415
+ writer.writerows(profiler_testlist)
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/heuristics_provider.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Providers for kernel selection heuristics
35
+ """
36
+
37
+ import sys
38
+ import os
39
+ import glob
40
+ import logging
41
+ import ctypes
42
+ import functools
43
+
44
+
45
+ try:
46
+ import builtins
47
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
48
+ raise ImportError("Disabling attempt to import cutlass_library")
49
+ from cutlass_library.library import DataType, LayoutType
50
+ except ImportError:
51
+ from library import DataType, LayoutType
52
+
53
+ class MatmulHeuristics:
54
+
55
+ def __init__(self, gpu = None):
56
+ import nvMatmulHeuristics
57
+ self.mmh_lib = nvMatmulHeuristics
58
+ self.gpu = gpu
59
+
60
+ if 'CUTLASS_NVMMH_SO_PATH' in os.environ:
61
+ nvmmhInterfaceEx = functools.partial(self.mmh_lib.NvMatmulHeuristicsInterfaceEx, path=os.environ['CUTLASS_NVMMH_SO_PATH'])
62
+ else:
63
+ nvmmhInterfaceEx = self.mmh_lib.NvMatmulHeuristicsInterfaceEx
64
+
65
+ self.lh = nvmmhInterfaceEx(
66
+ backend=self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"],
67
+ flags=self.mmh_lib.NvMatmulHeuristicsFlags.PERF_MODEL_BASED_AUTO_TUNING,
68
+ load_discovery_implicitly=True,
69
+ gpu=self.mmh_lib.NvMatmulHeuristicsNvidiaGpu[self.gpu] if self.gpu else None
70
+ )
71
+ self.backend = self.lh.createBackend(self.mmh_lib.NvMatmulHeuristicsTarget["CUTLASS3"])
72
+
73
+ def _layout_from_cutlass(self, layouts):
74
+ assert(len(layouts)==3)
75
+ full_layout_str = ''.join('t' if l == LayoutType.RowMajor else 'n' for l in layouts)
76
+ input_layouts = full_layout_str[:2].upper()
77
+ lh_layout = input_layouts + '_' + str("ROW_MAJOR" if full_layout_str[-1]=='t' else "COL_MAJOR")
78
+ return self.mmh_lib.NvMatmulHeuristicsMatmulLayout[lh_layout]
79
+
80
+ def _precision_from_cutlass_dtypes(self, dtypes):
81
+ dtype_to_cublas = {
82
+ DataType.f64: 'D',
83
+ DataType.f32: 'S',
84
+ DataType.f16: 'H',
85
+ DataType.bf16: 'T',
86
+ DataType.e4m3: 'Q',
87
+ DataType.e5m2: 'R',
88
+ DataType.s32: 'I',
89
+ DataType.s8: 'B',
90
+ }
91
+
92
+ dtype_a, dtype_b, dtype_compute, dtype_c, dtype_d = dtypes
93
+
94
+ a_c = dtype_to_cublas[dtype_a]
95
+
96
+ if a_c.lower() != 'q':
97
+ return a_c + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
98
+ else:
99
+ return a_c + dtype_to_cublas[dtype_b] + dtype_to_cublas[dtype_c] + dtype_to_cublas[dtype_compute] + dtype_to_cublas[dtype_d]
100
+
101
+ def set_cta_div_n(self, div_n):
102
+ cta_n_div_requirement = ctypes.c_int(div_n)
103
+ self.lh.setBackendValueProperty(
104
+ self.backend,
105
+ self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_N_DIV_REQUIREMENT,
106
+ ctypes.byref(cta_n_div_requirement),
107
+ ctypes.sizeof(cta_n_div_requirement)
108
+ )
109
+
110
+ def set_cta_div_m(self, div_m):
111
+ cta_m_div_requirement = ctypes.c_int(div_m)
112
+ self.lh.setBackendValueProperty(
113
+ self.backend,
114
+ self.mmh_lib.NvMatmulHeuristicsBackendProperty.CTA_TILE_M_DIV_REQUIREMENT,
115
+ ctypes.byref(cta_m_div_requirement),
116
+ ctypes.sizeof(cta_m_div_requirement)
117
+ )
118
+
119
+ def get_configs(self, m, n, k, batch_count, dtypes, layouts, align_a, align_b, voidC=False, use_fast_acc=True, count=1):
120
+ if use_fast_acc:
121
+ disable_fast_acc_for_fp8 = ctypes.c_int(0)
122
+ else:
123
+ disable_fast_acc_for_fp8 = ctypes.c_int(1)
124
+ self.lh.setBackendValueProperty(
125
+ self.backend,
126
+ self.mmh_lib.NvMatmulHeuristicsBackendProperty.DISABLE_FAST_ACC_FOR_FP8,
127
+ ctypes.byref(disable_fast_acc_for_fp8),
128
+ ctypes.sizeof(disable_fast_acc_for_fp8)
129
+ )
130
+
131
+ precision = self._precision_from_cutlass_dtypes(dtypes)
132
+ layout = self._layout_from_cutlass(layouts)
133
+
134
+ matmul_problem = self.lh.makeNvMatmulHeuristicsProblem(m, n, k, layout, batch_count)
135
+ configs = self.lh.getEx(matmul_problem, count, self.backend, precision=precision)
136
+
137
+ ret = []
138
+ for c in configs:
139
+ kernel = c['kernel']
140
+ problem = c['problem']
141
+
142
+ r = {}
143
+ r['estimated_runtime'] = c['runtime']
144
+ r['cta_tile_m'] = kernel.cta_tile_m
145
+ r['cta_tile_n'] = kernel.cta_tile_n
146
+ r['cta_tile_k'] = kernel.cta_tile_k
147
+ r['instr_tile_m'] = kernel.instr_tile_m
148
+ r['instr_tile_n'] = kernel.instr_tile_n
149
+ r['instr_tile_k'] = kernel.instr_tile_k
150
+ r['warp_tile_m'] = kernel.warp_tile_m
151
+ r['warp_tile_n'] = kernel.warp_tile_n
152
+ r['warp_tile_k'] = kernel.warp_tile_k
153
+ r['cluster_m'] = kernel.cluster_m
154
+ r['cluster_n'] = kernel.cluster_n
155
+ r['cluster_k'] = 1
156
+ r['layout_a'] = layouts[0]
157
+ r['layout_b'] = layouts[1]
158
+ r['layout_d'] = layouts[2]
159
+ r['dtype_a'] = dtypes[0]
160
+ r['dtype_b'] = dtypes[1]
161
+ r['dtype_acc'] = dtypes[2]
162
+ r['dtype_c'] = dtypes[3]
163
+ r['dtype_d'] = dtypes[4]
164
+ r['alignment_a'] = align_a
165
+ r['alignment_b'] = align_b
166
+ r['swizzle_size'] = kernel.swizzle_factor
167
+ r['raster_order'] = 'along_m' if kernel.cta_order==0 else 'along_n'
168
+ r['split_k_slices'] = kernel.split_k
169
+ r['use_fast_acc'] = use_fast_acc
170
+ r['voidC'] = voidC
171
+
172
+ ret.append(r)
173
+
174
+ return ret
175
+
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/library.py ADDED
@@ -0,0 +1,1531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Data types and tags used for emitting CUTLASS C++ kernels
35
+ """
36
+
37
+ import enum
38
+ import re
39
+
40
+ # The following block implements enum.auto() for Python 3.5 variants that don't include it such
41
+ # as the default 3.5.2 on Ubuntu 16.04.
42
+ #
43
+ # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
44
+
45
+ try:
46
+ from enum import auto as enum_auto
47
+ except ImportError:
48
+ __cutlass_library_auto_enum = 0
49
+ def enum_auto() -> int:
50
+ global __cutlass_library_auto_enum
51
+ i = __cutlass_library_auto_enum
52
+ __cutlass_library_auto_enum += 1
53
+ return i
54
+
55
+ ###################################################################################################
56
+
57
+ #
58
+ class GeneratorTarget(enum.Enum):
59
+ Library = enum_auto()
60
+ #
61
+ GeneratorTargetNames = {
62
+ GeneratorTarget.Library: 'library'
63
+ }
64
+ #
65
+
66
+ ###################################################################################################
67
+
68
+ #
69
+ class DataType(enum.Enum):
70
+ void = enum_auto() # primarily used to disable C tensor for epilogues
71
+ b1 = enum_auto()
72
+ u2 = enum_auto()
73
+ u4 = enum_auto()
74
+ u8 = enum_auto()
75
+ u16 = enum_auto()
76
+ u32 = enum_auto()
77
+ u64 = enum_auto()
78
+ s2 = enum_auto()
79
+ s4 = enum_auto()
80
+ s8 = enum_auto()
81
+ s16 = enum_auto()
82
+ s32 = enum_auto()
83
+ s64 = enum_auto()
84
+ e4m3 = enum_auto()
85
+ e5m2 = enum_auto()
86
+ f8 = enum_auto()
87
+ f6 = enum_auto()
88
+ f4 = enum_auto()
89
+ e3m2 = enum_auto()
90
+ e2m3 = enum_auto()
91
+ e2m1 = enum_auto()
92
+ ue8m0 = enum_auto()
93
+ ue4m3 = enum_auto()
94
+ f16 = enum_auto()
95
+ bf16 = enum_auto()
96
+ f32 = enum_auto()
97
+ tf32 = enum_auto()
98
+ f64 = enum_auto()
99
+ cf16 = enum_auto()
100
+ cbf16 = enum_auto()
101
+ cf32 = enum_auto()
102
+ ctf32 = enum_auto()
103
+ cf64 = enum_auto()
104
+ cs2 = enum_auto()
105
+ cs4 = enum_auto()
106
+ cs8 = enum_auto()
107
+ cs16 = enum_auto()
108
+ cs32 = enum_auto()
109
+ cs64 = enum_auto()
110
+ cu2 = enum_auto()
111
+ cu4 = enum_auto()
112
+ cu8 = enum_auto()
113
+ cu16 = enum_auto()
114
+ cu32 = enum_auto()
115
+ cu64 = enum_auto()
116
+ invalid = enum_auto()
117
+
118
+ #
119
+ ShortDataTypeNames = {
120
+ DataType.s32: 'i',
121
+ DataType.e4m3: 'e4m3',
122
+ DataType.e5m2: 'e5m2',
123
+ DataType.f16: 'h',
124
+ DataType.f32: 's',
125
+ DataType.f64: 'd',
126
+ DataType.cf32: 'c',
127
+ DataType.cf64: 'z',
128
+ DataType.f8: 'f8',
129
+ DataType.f6: 'f6',
130
+ DataType.f4: 'f4',
131
+ }
132
+
133
+ #
134
+ DataTypeNames = {
135
+ DataType.void: "void",
136
+ DataType.b1: "b1",
137
+ DataType.u2: "u2",
138
+ DataType.u4: "u4",
139
+ DataType.u8: "u8",
140
+ DataType.u16: "u16",
141
+ DataType.u32: "u32",
142
+ DataType.u64: "u64",
143
+ DataType.s2: "s2",
144
+ DataType.s4: "s4",
145
+ DataType.s8: "s8",
146
+ DataType.s16: "s16",
147
+ DataType.s32: "s32",
148
+ DataType.s64: "s64",
149
+ DataType.e4m3: 'e4m3',
150
+ DataType.e5m2: 'e5m2',
151
+ DataType.f8: 'f8',
152
+ DataType.f6: 'f6',
153
+ DataType.f4: 'f4',
154
+ DataType.e2m3: 'e2m3',
155
+ DataType.e3m2: 'e3m2',
156
+ DataType.e2m1: 'e2m1',
157
+ DataType.ue8m0: 'ue8m0',
158
+ DataType.ue4m3: 'ue4m3',
159
+ DataType.f16: "f16",
160
+ DataType.bf16: "bf16",
161
+ DataType.f32: "f32",
162
+ DataType.tf32: "tf32",
163
+ DataType.f64: "f64",
164
+ DataType.cf16: "cf16",
165
+ DataType.cbf16: "cbf16",
166
+ DataType.cf32: "cf32",
167
+ DataType.ctf32: "ctf32",
168
+ DataType.cf64: "cf64",
169
+ DataType.cu2: "cu2",
170
+ DataType.cu4: "cu4",
171
+ DataType.cu8: "cu8",
172
+ DataType.cu16: "cu16",
173
+ DataType.cu32: "cu32",
174
+ DataType.cu64: "cu64",
175
+ DataType.cs2: "cs2",
176
+ DataType.cs4: "cs4",
177
+ DataType.cs8: "cs8",
178
+ DataType.cs16: "cs16",
179
+ DataType.cs32: "cs32",
180
+ DataType.cs64: "cs64",
181
+ }
182
+
183
+ DataTypeTag = {
184
+ DataType.void: "void",
185
+ DataType.b1: "cutlass::uint1b_t",
186
+ DataType.u2: "cutlass::uint2b_t",
187
+ DataType.u4: "cutlass::uint4b_t",
188
+ DataType.u8: "uint8_t",
189
+ DataType.u16: "uint16_t",
190
+ DataType.u32: "uint32_t",
191
+ DataType.u64: "uint64_t",
192
+ DataType.s2: "cutlass::int2b_t",
193
+ DataType.s4: "cutlass::int4b_t",
194
+ DataType.s8: "int8_t",
195
+ DataType.s16: "int16_t",
196
+ DataType.s32: "int32_t",
197
+ DataType.s64: "int64_t",
198
+ DataType.e4m3: 'cutlass::float_e4m3_t',
199
+ DataType.e5m2: 'cutlass::float_e5m2_t',
200
+ DataType.f8: 'cutlass::type_erased_dynamic_float8_t',
201
+ DataType.f6: 'cutlass::type_erased_dynamic_float6_t',
202
+ DataType.f4: 'cutlass::type_erased_dynamic_float4_t',
203
+ DataType.e2m3: 'cutlass::float_e2m3_t',
204
+ DataType.e3m2: 'cutlass::float_e3m2_t',
205
+ DataType.e2m1: 'cutlass::float_e2m1_t',
206
+ DataType.ue8m0: 'cutlass::float_ue8m0_t',
207
+ DataType.ue4m3: 'cutlass::float_ue4m3_t',
208
+ DataType.f16: "cutlass::half_t",
209
+ DataType.bf16: "cutlass::bfloat16_t",
210
+ DataType.f32: "float",
211
+ DataType.tf32: "cutlass::tfloat32_t",
212
+ DataType.f64: "double",
213
+ DataType.cf16: "cutlass::complex<cutlass::half_t>",
214
+ DataType.cbf16: "cutlass::complex<cutlass::bfloat16_t>",
215
+ DataType.cf32: "cutlass::complex<float>",
216
+ DataType.ctf32: "cutlass::complex<cutlass::tfloat32_t>",
217
+ DataType.cf64: "cutlass::complex<double>",
218
+ DataType.cu2: "cutlass::complex<cutlass::uint2b_t>",
219
+ DataType.cu4: "cutlass::complex<cutlass::uint4b_t>",
220
+ DataType.cu8: "cutlass::complex<cutlass::uint8_t>",
221
+ DataType.cu16: "cutlass::complex<cutlass::uint16_t>",
222
+ DataType.cu32: "cutlass::complex<cutlass::uint32_t>",
223
+ DataType.cu64: "cutlass::complex<cutlass::uint64_t>",
224
+ DataType.cs2: "cutlass::complex<cutlass::int2b_t>",
225
+ DataType.cs4: "cutlass::complex<cutlass::int4b_t>",
226
+ DataType.cs8: "cutlass::complex<cutlass::int8_t>",
227
+ DataType.cs16: "cutlass::complex<cutlass::int16_t>",
228
+ DataType.cs32: "cutlass::complex<cutlass::int32_t>",
229
+ DataType.cs64: "cutlass::complex<cutlass::int64_t>",
230
+ }
231
+
232
+ DataTypeSize = {
233
+ DataType.void: 0,
234
+ DataType.b1: 1,
235
+ DataType.u2: 2,
236
+ DataType.u4: 4,
237
+ DataType.u8: 8,
238
+ DataType.u16: 16,
239
+ DataType.u32: 32,
240
+ DataType.u64: 64,
241
+ DataType.s2: 2,
242
+ DataType.s4: 4,
243
+ DataType.s8: 8,
244
+ DataType.s16: 16,
245
+ DataType.s32: 32,
246
+ DataType.s64: 64,
247
+ DataType.e4m3: 8,
248
+ DataType.e5m2: 8,
249
+ DataType.f8: 8,
250
+ DataType.f6: 6,
251
+ DataType.f4: 4,
252
+ DataType.e2m3: 6,
253
+ DataType.e3m2: 6,
254
+ DataType.e2m1: 4,
255
+ DataType.ue8m0: 8,
256
+ DataType.ue4m3: 8,
257
+ DataType.f16: 16,
258
+ DataType.bf16: 16,
259
+ DataType.f32: 32,
260
+ DataType.tf32: 32,
261
+ DataType.f64: 64,
262
+ DataType.cf16: 32,
263
+ DataType.cbf16: 32,
264
+ DataType.cf32: 64,
265
+ DataType.ctf32: 32,
266
+ DataType.cf64: 128,
267
+ DataType.cu2: 4,
268
+ DataType.cu4: 8,
269
+ DataType.cu8: 16,
270
+ DataType.cu16: 32,
271
+ DataType.cu32: 64,
272
+ DataType.cu64: 128,
273
+ DataType.cs2: 4,
274
+ DataType.cs4: 8,
275
+ DataType.cs8: 16,
276
+ DataType.cs16: 32,
277
+ DataType.cs32: 64,
278
+ DataType.cs64: 128,
279
+ }
280
+
281
+ ###################################################################################################
282
+ #
283
+ class BlasMode(enum.Enum):
284
+ symmetric = enum_auto()
285
+ hermitian = enum_auto()
286
+
287
+ #
288
+ BlasModeTag = {
289
+ BlasMode.symmetric: 'cutlass::BlasMode::kSymmetric',
290
+ BlasMode.hermitian: 'cutlass::BlasMode::kHermitian',
291
+ }
292
+
293
+ #
294
+ class ComplexTransform(enum.Enum):
295
+ none = enum_auto()
296
+ conj = enum_auto()
297
+
298
+ #
299
+ ComplexTransformTag = {
300
+ ComplexTransform.none: 'cutlass::ComplexTransform::kNone',
301
+ ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate',
302
+ }
303
+
304
+ # Used for cutlass3x complex kernel collective mainloop builder instantiation
305
+ ComplexTransformTag3x = {
306
+ ComplexTransform.none: 'cute::identity',
307
+ ComplexTransform.conj: 'cute::conjugate',
308
+ }
309
+
310
+ #
311
+ RealComplexBijection = [
312
+ (DataType.f16, DataType.cf16),
313
+ (DataType.f32, DataType.cf32),
314
+ (DataType.f64, DataType.cf64),
315
+ ]
316
+
317
+ #
318
+ def is_complex(data_type):
319
+ for r, c in RealComplexBijection:
320
+ if data_type == c:
321
+ return True
322
+ return False
323
+
324
+ def is_block_scaled(gemm_kind):
325
+ return gemm_kind in (GemmKind.BlockScaledUniversal3x, GemmKind.GroupedBlockScaledUniversal3x)
326
+
327
+ def is_blockwise(gemm_kind):
328
+ return gemm_kind in (GemmKind.BlockwiseUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
329
+
330
+ def is_grouped(gemm_kind):
331
+ return gemm_kind in (GemmKind.GroupedUniversal3x,
332
+ GemmKind.GroupedBlockScaledUniversal3x, GemmKind.GroupedBlockwiseUniversal3x)
333
+
334
+ #
335
+ def get_complex_from_real(real_type):
336
+ for r, c in RealComplexBijection:
337
+ if real_type == r:
338
+ return c
339
+ return DataType.invalid
340
+
341
+ #
342
+ def get_real_from_complex(complex_type):
343
+ for r, c in RealComplexBijection:
344
+ if complex_type == c:
345
+ return r
346
+ return DataType.invalid
347
+
348
+ # TMA requires an alignment of 128 bits for all data types
349
+ def get_tma_alignment(data_type):
350
+ if data_type == DataType.void:
351
+ return 0
352
+ elif DataTypeSize[data_type] == 6:
353
+ return 128 # 96B alignment for 16U6 format
354
+ else:
355
+ return 128 // DataTypeSize[data_type]
356
+
357
+ #
358
+ class ComplexMultiplyOp(enum.Enum):
359
+ multiply_add = enum_auto()
360
+ gaussian = enum_auto()
361
+
362
+ ###################################################################################################
363
+
364
+ #
365
+ class MathOperation(enum.Enum):
366
+ multiply_add = enum_auto()
367
+ multiply_add_saturate = enum_auto()
368
+ multiply_add_mixed_input_upcast = enum_auto()
369
+ xor_popc = enum_auto()
370
+ and_popc = enum_auto()
371
+ multiply_add_fast_bf16 = enum_auto()
372
+ multiply_add_fast_f16 = enum_auto()
373
+ multiply_add_fast_f32 = enum_auto()
374
+ multiply_add_complex_fast_f32 = enum_auto()
375
+ multiply_add_complex = enum_auto()
376
+ multiply_add_complex_gaussian = enum_auto()
377
+ multiply_add_fast_accum = enum_auto()
378
+
379
+ #
380
+ MathOperationTag = {
381
+ MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
382
+ MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
383
+ MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
384
+ MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
385
+ MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
386
+ MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
387
+ MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16',
388
+ MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32',
389
+ MathOperation.multiply_add_complex_fast_f32: 'cutlass::arch::OpMultiplyAddComplexFastF32',
390
+ MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex',
391
+ MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex',
392
+ MathOperation.multiply_add_fast_accum: 'cutlass::arch::OpMultiplyAddFastAccum',
393
+ }
394
+
395
+ ###################################################################################################
396
+
397
+ #
398
+ class LayoutType(enum.Enum):
399
+ ColumnMajor = enum_auto()
400
+ RowMajor = enum_auto()
401
+ ColumnMajorInterleaved2 = enum_auto()
402
+ RowMajorInterleaved2 = enum_auto()
403
+ ColumnMajorInterleaved32 = enum_auto()
404
+ RowMajorInterleaved32 = enum_auto()
405
+ ColumnMajorInterleaved64 = enum_auto()
406
+ RowMajorInterleaved64 = enum_auto()
407
+ TensorNWC = enum_auto()
408
+ TensorNHWC = enum_auto()
409
+ TensorNDHWC = enum_auto()
410
+ TensorNCHW = enum_auto()
411
+ TensorNGHWC = enum_auto()
412
+ TensorNC32HW32 = enum_auto()
413
+ TensorNC64HW64 = enum_auto()
414
+ TensorC32RSK32 = enum_auto()
415
+ TensorC64RSK64 = enum_auto()
416
+ TensorKCS = enum_auto()
417
+ TensorKCSR = enum_auto()
418
+ TensorKCSRT = enum_auto()
419
+
420
+ #
421
+ LayoutTag = {
422
+ LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor',
423
+ LayoutType.RowMajor: 'cutlass::layout::RowMajor',
424
+ LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>',
425
+ LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>',
426
+ LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>',
427
+ LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>',
428
+ LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>',
429
+ LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>',
430
+ LayoutType.TensorNWC: 'cutlass::layout::TensorNWC',
431
+ LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC',
432
+ LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC',
433
+ LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW',
434
+ LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC',
435
+ LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>',
436
+ LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>',
437
+ LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>',
438
+ LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>',
439
+ LayoutType.TensorKCS: 'cutlass::layout::TensorKCS',
440
+ LayoutType.TensorKCSR: 'cutlass::layout::TensorKCSR',
441
+ LayoutType.TensorKCSRT: 'cutlass::layout::TensorKCSRT'
442
+ }
443
+
444
+ #
445
+ TransposedLayout = {
446
+ LayoutType.ColumnMajor: LayoutType.RowMajor,
447
+ LayoutType.RowMajor: LayoutType.ColumnMajor,
448
+ LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2,
449
+ LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2,
450
+ LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32,
451
+ LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32,
452
+ LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64,
453
+ LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64,
454
+ LayoutType.TensorNHWC: LayoutType.TensorNHWC
455
+ }
456
+
457
+ #
458
+ ShortLayoutTypeNames = {
459
+ LayoutType.ColumnMajor: 'n',
460
+ LayoutType.ColumnMajorInterleaved2: 'n2',
461
+ LayoutType.ColumnMajorInterleaved32: 'n32',
462
+ LayoutType.ColumnMajorInterleaved64: 'n64',
463
+ LayoutType.RowMajor: 't',
464
+ LayoutType.RowMajorInterleaved2: 't2',
465
+ LayoutType.RowMajorInterleaved32: 't32',
466
+ LayoutType.RowMajorInterleaved64: 't64',
467
+ LayoutType.TensorNWC: 'nwc',
468
+ LayoutType.TensorNHWC: 'nhwc',
469
+ LayoutType.TensorNDHWC: 'ndhwc',
470
+ LayoutType.TensorNCHW: 'nchw',
471
+ LayoutType.TensorNGHWC: 'nghwc',
472
+ LayoutType.TensorNC32HW32: 'nc32hw32',
473
+ LayoutType.TensorNC64HW64: 'nc64hw64',
474
+ LayoutType.TensorC32RSK32: 'c32rsk32',
475
+ LayoutType.TensorC64RSK64: 'c64rsk64',
476
+ LayoutType.TensorKCS: 'kcs',
477
+ LayoutType.TensorKCSR: 'kcsr',
478
+ LayoutType.TensorKCSRT: 'kcsrt'
479
+ }
480
+
481
+ #
482
+ ShortComplexLayoutNames = {
483
+ (LayoutType.ColumnMajor, ComplexTransform.none): 'n',
484
+ (LayoutType.ColumnMajor, ComplexTransform.conj): 'c',
485
+ (LayoutType.RowMajor, ComplexTransform.none): 't',
486
+ (LayoutType.RowMajor, ComplexTransform.conj): 'h'
487
+ }
488
+
489
+ ###################################################################################################
490
+ class KernelScheduleType(enum.Enum):
491
+ ScheduleAuto = enum_auto()
492
+ Multistage = enum_auto()
493
+ CpAsyncWarpSpecialized = enum_auto()
494
+ CpAsyncWarpSpecializedPingpong = enum_auto()
495
+ CpAsyncWarpSpecializedCooperative = enum_auto()
496
+ Tma = enum_auto()
497
+ TmaWarpSpecialized = enum_auto()
498
+ TmaWarpSpecializedPingpong = enum_auto()
499
+ TmaWarpSpecializedCooperative = enum_auto()
500
+ TmaWarpSpecializedFP8FastAccum = enum_auto()
501
+ TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
502
+ TmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
503
+ ImplicitTmaWarpSpecializedSm90 = enum_auto()
504
+ PtrArrayTmaWarpSpecializedCooperative = enum_auto()
505
+ PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum = enum_auto()
506
+ PtrArrayTmaWarpSpecializedPingpong = enum_auto()
507
+ PtrArrayTmaWarpSpecializedPingpongFP8FastAccum = enum_auto()
508
+
509
+ BlockwiseTmaWarpSpecializedCooperative = enum_auto()
510
+ PtrArrayBlockwiseTmaWarpSpecializedCooperative = enum_auto()
511
+ BlockwiseTmaWarpSpecializedPingpong = enum_auto()
512
+ PtrArrayBlockwiseTmaWarpSpecializedPingpong = enum_auto()
513
+
514
+ TmaWarpSpecialized1SmSm100 = enum_auto()
515
+ TmaWarpSpecialized2SmSm100 = enum_auto()
516
+ ImplicitTmaWarpSpecialized1SmSm100 = enum_auto()
517
+ ImplicitTmaWarpSpecialized2SmSm100 = enum_auto()
518
+
519
+ PtrArrayTmaWarpSpecialized1SmSm100 = enum_auto()
520
+ PtrArrayTmaWarpSpecialized2SmSm100 = enum_auto()
521
+
522
+ PtrArrayTmaWarpSpecialized1SmBlockScaledSm100 = enum_auto()
523
+ PtrArrayTmaWarpSpecialized2SmBlockScaledSm100 = enum_auto()
524
+ PtrArrayNvf4TmaWarpSpecialized1SmSm100 = enum_auto()
525
+ PtrArrayNvf4TmaWarpSpecialized2SmSm100 = enum_auto()
526
+ PtrArrayMxf4TmaWarpSpecialized1SmSm100 = enum_auto()
527
+ PtrArrayMxf4TmaWarpSpecialized2SmSm100 = enum_auto()
528
+ PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
529
+ PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
530
+
531
+ SparseTmaWarpSpecialized1SmSm100 = enum_auto()
532
+ SparseTmaWarpSpecialized2SmSm100 = enum_auto()
533
+
534
+ BlockScaledTmaWarpSpecialized1SmSm100 = enum_auto()
535
+ BlockScaledTmaWarpSpecialized2SmSm100 = enum_auto()
536
+ Mxf8f6f4TmaWarpSpecialized1SmSm100 = enum_auto()
537
+ Mxf8f6f4TmaWarpSpecialized2SmSm100 = enum_auto()
538
+
539
+ BlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
540
+ BlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
541
+
542
+ PtrArrayBlockwiseTmaWarpSpecialized1SmSm100 = enum_auto()
543
+ PtrArrayBlockwiseTmaWarpSpecialized2SmSm100 = enum_auto()
544
+
545
+
546
+ Mxf4TmaWarpSpecialized1SmSm100 = enum_auto()
547
+ Mxf4TmaWarpSpecialized2SmSm100 = enum_auto()
548
+ Nvf4TmaWarpSpecialized1SmSm100 = enum_auto()
549
+ Nvf4TmaWarpSpecialized2SmSm100 = enum_auto()
550
+
551
+ # FP4 Ultra
552
+ MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
553
+ MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
554
+ MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
555
+ MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
556
+
557
+ MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
558
+ MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
559
+ MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
560
+ MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
561
+
562
+ MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
563
+ MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
564
+ MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
565
+ MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
566
+
567
+ PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103 = enum_auto()
568
+ PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103 = enum_auto()
569
+ PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103 = enum_auto()
570
+ PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103 = enum_auto()
571
+
572
+ PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch = enum_auto()
573
+ PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch = enum_auto()
574
+ PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch = enum_auto()
575
+ PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch = enum_auto()
576
+
577
+ PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch = enum_auto()
578
+ PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch = enum_auto()
579
+ PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch = enum_auto()
580
+ PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch = enum_auto()
581
+
582
+ Mxf8f6f4TmaWarpSpecializedCooperativeSm120 = enum_auto()
583
+ Mxf8f6f4TmaWarpSpecializedPingpongSm120 = enum_auto()
584
+ Nvf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
585
+ Nvf4TmaWarpSpecializedPingpongSm120 = enum_auto()
586
+ Mxf4TmaWarpSpecializedCooperativeSm120 = enum_auto()
587
+ Mxf4TmaWarpSpecializedPingpongSm120 = enum_auto()
588
+
589
+ F8f6f4SparseTmaWarpSpecializedCooperativeSm120 = enum_auto()
590
+
591
+ BlockwiseTmaWarpSpecializedCooperativeSm120 = enum_auto()
592
+ BlockwiseTmaWarpSpecializedPingpongSm120 = enum_auto()
593
+
594
+ KernelScheduleTag = {
595
+ KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto',
596
+ KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage',
597
+ KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized',
598
+ KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong',
599
+ KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative',
600
+ KernelScheduleType.Tma: 'cutlass::gemm::KernelTma',
601
+ KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized',
602
+ KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong',
603
+ KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative',
604
+ KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum',
605
+ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum',
606
+ KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum',
607
+ KernelScheduleType.ImplicitTmaWarpSpecializedSm90: 'cutlass::conv::KernelImplicitTmaWarpSpecializedSm90',
608
+
609
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8Blockwise',
610
+ KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8Blockwise',
611
+
612
+ KernelScheduleType.TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmSm100',
613
+ KernelScheduleType.TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmSm100',
614
+
615
+ KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized1SmSm100',
616
+ KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: 'cutlass::conv::KernelImplicitTmaWarpSpecialized2SmSm100',
617
+
618
+ KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100',
619
+ KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100',
620
+
621
+ KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized1SmSm100',
622
+ KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelSparseTmaWarpSpecialized2SmSm100',
623
+
624
+ KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledSm100',
625
+ KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledSm100',
626
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf8f6f4Sm100',
627
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf8f6f4Sm100',
628
+
629
+ KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100',
630
+ KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100',
631
+
632
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100',
633
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100',
634
+
635
+ KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmMxf4Sm100',
636
+ KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmMxf4Sm100',
637
+ KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized1SmNvf4Sm100',
638
+ KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: 'cutlass::gemm::KernelTmaWarpSpecialized2SmNvf4Sm100',
639
+
640
+ # FP4 Ultra
641
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
642
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
643
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
644
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
645
+
646
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
647
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
648
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
649
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
650
+
651
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
652
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
653
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
654
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
655
+
656
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative',
657
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum',
658
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong',
659
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum',
660
+
661
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8Blockwise',
662
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8Blockwise',
663
+
664
+ KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledSm100",
665
+ KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledSm100",
666
+ KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100",
667
+ KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmNvf4Sm100",
668
+ KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf4Sm100",
669
+ KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf4Sm100",
670
+ KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100",
671
+ KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: "cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmMxf8f6f4Sm100",
672
+
673
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103',
674
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103',
675
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103',
676
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103',
677
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
678
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103TmaPrefetch',
679
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
680
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103TmaPrefetch',
681
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
682
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103DisablePrefetch',
683
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
684
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: 'cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs32Sm103DisablePrefetch',
685
+
686
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf8f6f4Sm120',
687
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf8f6f4Sm120',
688
+ KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedNvf4Sm120',
689
+ KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongNvf4Sm120',
690
+ KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedMxf4Sm120',
691
+ KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongMxf4Sm120',
692
+
693
+ KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelScheduleSparseF8f6f4Sm120',
694
+
695
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120',
696
+ KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: 'cutlass::gemm::KernelTmaWarpSpecializedBlockwisePingpongSm120',
697
+ }
698
+
699
+ #
700
+ KernelScheduleSuffixes = {
701
+ KernelScheduleType.ScheduleAuto: '',
702
+ KernelScheduleType.Multistage: '_cpasync',
703
+ KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized',
704
+ KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong',
705
+ KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative',
706
+ KernelScheduleType.Tma: '_unspecialized',
707
+ KernelScheduleType.TmaWarpSpecialized: '_warpspecialized',
708
+ KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
709
+ KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
710
+ KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum',
711
+ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
712
+ KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
713
+ KernelScheduleType.ImplicitTmaWarpSpecializedSm90: '_warpspecialized',
714
+
715
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
716
+ KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
717
+
718
+ KernelScheduleType.TmaWarpSpecialized1SmSm100: '_1sm',
719
+ KernelScheduleType.TmaWarpSpecialized2SmSm100: '_2sm',
720
+
721
+ KernelScheduleType.ImplicitTmaWarpSpecialized1SmSm100: '_1sm',
722
+ KernelScheduleType.ImplicitTmaWarpSpecialized2SmSm100: '_2sm',
723
+
724
+ KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100: '_1sm',
725
+ KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100: '_2sm',
726
+
727
+ KernelScheduleType.SparseTmaWarpSpecialized1SmSm100: '_1sm',
728
+ KernelScheduleType.SparseTmaWarpSpecialized2SmSm100: '_2sm',
729
+
730
+ KernelScheduleType.BlockScaledTmaWarpSpecialized1SmSm100: '_1sm',
731
+ KernelScheduleType.BlockScaledTmaWarpSpecialized2SmSm100: '_2sm',
732
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100: '_q_1sm',
733
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100: '_q_2sm',
734
+
735
+ KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
736
+ KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
737
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100: '_1sm',
738
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100: '_2sm',
739
+
740
+ KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
741
+ KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
742
+ KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
743
+ KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
744
+
745
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm',
746
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm',
747
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm',
748
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm',
749
+
750
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf',
751
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf',
752
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf',
753
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf',
754
+
755
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf',
756
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf',
757
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf',
758
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf',
759
+
760
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
761
+ KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum',
762
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
763
+ KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum',
764
+
765
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative: '_warpspecialized_cooperative',
766
+ KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong: '_warpspecialized_pingpong',
767
+
768
+ KernelScheduleType.PtrArrayTmaWarpSpecialized1SmBlockScaledSm100: '_1sm',
769
+ KernelScheduleType.PtrArrayTmaWarpSpecialized2SmBlockScaledSm100: '_2sm',
770
+ KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100: '_o_vs16_1sm',
771
+ KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100: '_o_vs16_2sm',
772
+ KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
773
+ KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
774
+ KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100: '_o_vs32_1sm',
775
+ KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100: '_o_vs32_2sm',
776
+
777
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: '_o_vs16_ultra_1sm',
778
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: '_o_vs16_ultra_2sm',
779
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: '_o_vs32_ultra_1sm',
780
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: '_o_vs32_ultra_2sm',
781
+
782
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_1sm_nopf',
783
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: '_o_vs16_ultra_2sm_nopf',
784
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_1sm_nopf',
785
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: '_o_vs32_ultra_2sm_nopf',
786
+
787
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_1sm_tmapf',
788
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: '_o_vs16_ultra_2sm_tmapf',
789
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_1sm_tmapf',
790
+ KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: '_o_vs32_ultra_2sm_tmapf',
791
+
792
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecializedCooperativeSm120: '_cooperative_q',
793
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecializedPingpongSm120: '_pingpong_q',
794
+ KernelScheduleType.Nvf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs16',
795
+ KernelScheduleType.Nvf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs16',
796
+ KernelScheduleType.Mxf4TmaWarpSpecializedCooperativeSm120: '_cooperative_o_vs32',
797
+ KernelScheduleType.Mxf4TmaWarpSpecializedPingpongSm120: '_pingpong_o_vs32',
798
+
799
+ KernelScheduleType.F8f6f4SparseTmaWarpSpecializedCooperativeSm120: '_q',
800
+
801
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperativeSm120: '_cooperative_q',
802
+ KernelScheduleType.BlockwiseTmaWarpSpecializedPingpongSm120: '_pingpong_q'
803
+ }
804
+
805
+ class EpilogueScheduleType(enum.Enum):
806
+ ScheduleAuto = enum_auto()
807
+ EpilogueTransposed = enum_auto()
808
+ NoSmemWarpSpecialized = enum_auto()
809
+ PtrArrayNoSmemWarpSpecialized = enum_auto()
810
+ NoSmemWarpSpecialized1Sm = enum_auto()
811
+ NoSmemWarpSpecialized2Sm = enum_auto()
812
+ FastF32NoSmemWarpSpecialized1Sm = enum_auto()
813
+ FastF32NoSmemWarpSpecialized2Sm = enum_auto()
814
+ BlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
815
+ BlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
816
+ PtrArrayNoSmemWarpSpecialized1Sm = enum_auto()
817
+ PtrArrayNoSmemWarpSpecialized2Sm = enum_auto()
818
+ PtrArrayFastF32NoSmemWarpSpecialized1Sm = enum_auto()
819
+ PtrArrayFastF32NoSmemWarpSpecialized2Sm = enum_auto()
820
+ PtrArrayBlockwiseNoSmemWarpSpecialized1Sm = enum_auto()
821
+ PtrArrayBlockwiseNoSmemWarpSpecialized2Sm = enum_auto()
822
+ TmaWarpSpecialized = enum_auto()
823
+ TmaWarpSpecializedCooperative = enum_auto()
824
+ TmaWarpSpecialized1Sm = enum_auto()
825
+ TmaWarpSpecialized2Sm = enum_auto()
826
+ PtrArrayTmaWarpSpecialized1Sm = enum_auto()
827
+ PtrArrayTmaWarpSpecialized2Sm = enum_auto()
828
+ PtrArrayTmaWarpSpecializedPingpong = enum_auto()
829
+ PtrArrayTmaWarpSpecializedCooperative = enum_auto()
830
+
831
+ #
832
+ EpilogueScheduleTag = {
833
+ EpilogueScheduleType.ScheduleAuto: 'cutlass::epilogue::collective::EpilogueScheduleAuto',
834
+ EpilogueScheduleType.EpilogueTransposed: 'cutlass::gemm::EpilogueTransposed',
835
+ EpilogueScheduleType.NoSmemWarpSpecialized: 'cutlass::epilogue::NoSmemWarpSpecialized',
836
+ EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized',
837
+ EpilogueScheduleType.NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::NoSmemWarpSpecialized1Sm',
838
+ EpilogueScheduleType.NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::NoSmemWarpSpecialized2Sm',
839
+ EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized1Sm',
840
+ EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::FastF32NoSmemWarpSpecialized2Sm',
841
+ EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized1Sm',
842
+ EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::BlockwiseNoSmemWarpSpecialized2Sm',
843
+ EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized1Sm',
844
+ EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayNoSmemWarpSpecialized2Sm',
845
+ EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized1Sm',
846
+ EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayFastF32NoSmemWarpSpecialized2Sm',
847
+ EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized1Sm',
848
+ EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayBlockwiseNoSmemWarpSpecialized2Sm',
849
+ EpilogueScheduleType.TmaWarpSpecialized: 'cutlass::epilogue::TmaWarpSpecialized',
850
+ EpilogueScheduleType.TmaWarpSpecializedCooperative: 'cutlass::epilogue::TmaWarpSpecializedCooperative',
851
+ EpilogueScheduleType.TmaWarpSpecialized1Sm: 'cutlass::epilogue::TmaWarpSpecialized1Sm',
852
+ EpilogueScheduleType.TmaWarpSpecialized2Sm: 'cutlass::epilogue::TmaWarpSpecialized2Sm',
853
+ EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm',
854
+ EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: 'cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm',
855
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative',
856
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: 'cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong',
857
+ }
858
+
859
+ #
860
+ EpilogueScheduleSuffixes = {
861
+ EpilogueScheduleType.ScheduleAuto: '',
862
+ EpilogueScheduleType.EpilogueTransposed: '',
863
+ EpilogueScheduleType.NoSmemWarpSpecialized: '_epi_nosmem',
864
+ EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized: '_epi_nosmem',
865
+ EpilogueScheduleType.NoSmemWarpSpecialized1Sm: '_epi_nosmem',
866
+ EpilogueScheduleType.NoSmemWarpSpecialized2Sm: '_epi_nosmem',
867
+ EpilogueScheduleType.FastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
868
+ EpilogueScheduleType.FastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
869
+ EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
870
+ EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
871
+ EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm: '_epi_nosmem',
872
+ EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm: '_epi_nosmem',
873
+ EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized1Sm: '_epi_nosmem_fastf32',
874
+ EpilogueScheduleType.PtrArrayFastF32NoSmemWarpSpecialized2Sm: '_epi_nosmem_fastf32',
875
+ EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm: '_epi_nosmem',
876
+ EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm: '_epi_nosmem',
877
+ EpilogueScheduleType.TmaWarpSpecialized: '_epi_tma',
878
+ EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma',
879
+ EpilogueScheduleType.TmaWarpSpecialized1Sm: '',
880
+ EpilogueScheduleType.TmaWarpSpecialized2Sm: '_epi_tma',
881
+ EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm: '',
882
+ EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm: '_epi_tma',
883
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative: '_epi_tma',
884
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong: '_epi_tma',
885
+ }
886
+
887
+ class EpilogueFunctor3x(enum.Enum):
888
+ LinearCombination = enum_auto()
889
+ LinearCombinationBlockScaleFactor = enum_auto()
890
+
891
+ #
892
+ EpilogueFunctor3xTag = {
893
+ EpilogueFunctor3x.LinearCombination: 'cutlass::epilogue::fusion::LinearCombination',
894
+ EpilogueFunctor3x.LinearCombinationBlockScaleFactor: 'cutlass::epilogue::fusion::LinCombBlockScaleFactor',
895
+ }
896
+
897
+ # TMA epilogues have certain alignment requirements as calculated in get_tma_alignment(data_type)
898
+ def is_tma_epilogue(epilogue_schedule_type):
899
+ return epilogue_schedule_type in [
900
+ EpilogueScheduleType.ScheduleAuto,
901
+ EpilogueScheduleType.TmaWarpSpecialized,
902
+ EpilogueScheduleType.TmaWarpSpecializedCooperative,
903
+ EpilogueScheduleType.TmaWarpSpecialized1Sm,
904
+ EpilogueScheduleType.TmaWarpSpecialized2Sm,
905
+ EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
906
+ EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
907
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
908
+ EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
909
+ ]
910
+
911
+ def to_grouped_schedule(schedule, grouped):
912
+ if not grouped:
913
+ return schedule
914
+
915
+ group_schedule_map = {
916
+ # SM90
917
+ KernelScheduleType.TmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperative,
918
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedCooperative,
919
+ KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecializedPingpong,
920
+ KernelScheduleType.TmaWarpSpecializedPingpong : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpong,
921
+ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedCooperativeFP8FastAccum,
922
+ KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum : KernelScheduleType.PtrArrayTmaWarpSpecializedPingpongFP8FastAccum,
923
+ EpilogueScheduleType.TmaWarpSpecialized : EpilogueScheduleType.PtrArrayTmaWarpSpecializedPingpong,
924
+ EpilogueScheduleType.TmaWarpSpecializedCooperative : EpilogueScheduleType.PtrArrayTmaWarpSpecializedCooperative,
925
+ EpilogueScheduleType.NoSmemWarpSpecialized : EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized,
926
+ # SM100
927
+ KernelScheduleType.TmaWarpSpecialized1SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized1SmSm100,
928
+ KernelScheduleType.TmaWarpSpecialized2SmSm100: KernelScheduleType.PtrArrayTmaWarpSpecialized2SmSm100,
929
+ KernelScheduleType.Nvf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized1SmSm100,
930
+ KernelScheduleType.Nvf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayNvf4TmaWarpSpecialized2SmSm100,
931
+ KernelScheduleType.Mxf4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized1SmSm100,
932
+ KernelScheduleType.Mxf4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf4TmaWarpSpecialized2SmSm100,
933
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized1SmSm100,
934
+ KernelScheduleType.Mxf8f6f4TmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayMxf8f6f4TmaWarpSpecialized2SmSm100,
935
+ KernelScheduleType.BlockwiseTmaWarpSpecialized1SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized1SmSm100,
936
+ KernelScheduleType.BlockwiseTmaWarpSpecialized2SmSm100 : KernelScheduleType.PtrArrayBlockwiseTmaWarpSpecialized2SmSm100,
937
+ EpilogueScheduleType.TmaWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized1Sm,
938
+ EpilogueScheduleType.TmaWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayTmaWarpSpecialized2Sm,
939
+ EpilogueScheduleType.NoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized1Sm,
940
+ EpilogueScheduleType.NoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayNoSmemWarpSpecialized2Sm,
941
+ EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized1Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized1Sm,
942
+ EpilogueScheduleType.BlockwiseNoSmemWarpSpecialized2Sm: EpilogueScheduleType.PtrArrayBlockwiseNoSmemWarpSpecialized2Sm,
943
+ # SM103
944
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103,
945
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103,
946
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103,
947
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103,
948
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103DisablePrefetch,
949
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103DisablePrefetch,
950
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103DisablePrefetch,
951
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103DisablePrefetch,
952
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs16Sm103TmaPrefetch,
953
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs16Sm103TmaPrefetch,
954
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized1SmVs32Sm103TmaPrefetch,
955
+ KernelScheduleType.MxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch: KernelScheduleType.PtrArrayMxNvf4UltraTmaWarpSpecialized2SmVs32Sm103TmaPrefetch,
956
+ }
957
+
958
+ return group_schedule_map[schedule]
959
+
960
+ class TileSchedulerType(enum.Enum):
961
+ Default = enum_auto()
962
+ Persistent = enum_auto()
963
+ StreamK = enum_auto()
964
+ #
965
+ TileSchedulerTag = {
966
+ TileSchedulerType.Default: 'void',
967
+ TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler',
968
+ TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler',
969
+ }
970
+
971
+ #
972
+ TileSchedulerSuffixes = {
973
+ TileSchedulerType.Default: '',
974
+ TileSchedulerType.Persistent: '',
975
+ TileSchedulerType.StreamK: '_stream_k',
976
+ }
977
+
978
+ ###################################################################################################
979
+
980
+ #
981
+ class SideMode(enum.Enum):
982
+ Left = enum_auto()
983
+ Right = enum_auto()
984
+
985
+ #
986
+ SideModeTag = {
987
+ SideMode.Left: 'cutlass::SideMode::kLeft',
988
+ SideMode.Right: 'cutlass::SideMode::kRight'
989
+ }
990
+
991
+ #
992
+ ShortSideModeNames = {
993
+ SideMode.Left: 'ls',
994
+ SideMode.Right: 'rs'
995
+ }
996
+
997
+ ###################################################################################################
998
+
999
+ #
1000
+ class FillMode(enum.Enum):
1001
+ Lower = enum_auto()
1002
+ Upper = enum_auto()
1003
+
1004
+ #
1005
+ FillModeTag = {
1006
+ FillMode.Lower: 'cutlass::FillMode::kLower',
1007
+ FillMode.Upper: 'cutlass::FillMode::kUpper'
1008
+ }
1009
+
1010
+ #
1011
+ ShortFillModeNames = {
1012
+ FillMode.Lower: 'l',
1013
+ FillMode.Upper: 'u'
1014
+ }
1015
+
1016
+ ###################################################################################################
1017
+
1018
+ #
1019
+ class DiagType(enum.Enum):
1020
+ NonUnit = enum_auto()
1021
+ Unit = enum_auto()
1022
+
1023
+ #
1024
+ DiagTypeTag = {
1025
+ DiagType.NonUnit: 'cutlass::DiagType::kNonUnit',
1026
+ DiagType.Unit: 'cutlass::DiagType::kUnit'
1027
+ }
1028
+
1029
+ #
1030
+ ShortDiagTypeNames = {
1031
+ DiagType.NonUnit: 'nu',
1032
+ DiagType.Unit: 'un'
1033
+ }
1034
+
1035
+ ###################################################################################################
1036
+
1037
+ #
1038
+ class OpcodeClass(enum.Enum):
1039
+ Simt = enum_auto()
1040
+ TensorOp = enum_auto()
1041
+ WmmaTensorOp = enum_auto()
1042
+ SparseTensorOp = enum_auto()
1043
+ BlockScaledTensorOp = enum_auto()
1044
+
1045
+
1046
+ OpcodeClassNames = {
1047
+ OpcodeClass.Simt: 'simt',
1048
+ OpcodeClass.TensorOp: 'tensorop',
1049
+ OpcodeClass.WmmaTensorOp: 'wmma_tensorop',
1050
+ OpcodeClass.SparseTensorOp: 'sptensorop',
1051
+ OpcodeClass.BlockScaledTensorOp: 'bstensorop'
1052
+ }
1053
+
1054
+ OpcodeClassTag = {
1055
+ OpcodeClass.Simt: 'cutlass::arch::OpClassSimt',
1056
+ OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp',
1057
+ OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp',
1058
+ OpcodeClass.SparseTensorOp: 'cutlass::arch::OpClassSparseTensorOp',
1059
+ OpcodeClass.BlockScaledTensorOp: 'cutlass::arch::OpClassBlockScaledTensorOp'
1060
+ }
1061
+
1062
+ ###################################################################################################
1063
+
1064
+ #
1065
+ class OperationKind(enum.Enum):
1066
+ Gemm = enum_auto()
1067
+ RankK = enum_auto()
1068
+ Rank2K = enum_auto()
1069
+ Trmm = enum_auto()
1070
+ Symm = enum_auto()
1071
+ Conv2d = enum_auto()
1072
+ Conv3d = enum_auto()
1073
+
1074
+ #
1075
+ OperationKindNames = {
1076
+ OperationKind.Gemm: 'gemm'
1077
+ , OperationKind.RankK: 'rank_k'
1078
+ , OperationKind.Rank2K: 'rank_2k'
1079
+ , OperationKind.Trmm: 'trmm'
1080
+ , OperationKind.Symm: 'symm'
1081
+ , OperationKind.Conv2d: 'conv2d'
1082
+ , OperationKind.Conv3d: 'conv3d'
1083
+ }
1084
+
1085
+ #
1086
+ class Target(enum.Enum):
1087
+ library = enum_auto()
1088
+ #
1089
+ ArchitectureNames = {
1090
+ 50: 'maxwell',
1091
+ 60: 'pascal',
1092
+ 61: 'pascal',
1093
+ 70: 'volta',
1094
+ 75: 'turing',
1095
+ 80: 'ampere',
1096
+ 89: 'ada',
1097
+ 90: 'hopper'
1098
+ }
1099
+
1100
+ #
1101
+ SharedMemPerCC = {
1102
+ 70: 96, # 96KB of SMEM
1103
+ 72: 96, # 96KB of SMEM
1104
+ 75: 64, # 64KB of SMEM
1105
+ 80: 163, # 163KB of SMEM - 1KB reserved for the driver
1106
+ 86: 99, # 99KB of SMEM - 1KB reserved for the driver
1107
+ 87: 163, # 163KB of SMEM - 1KB reserved for the driver
1108
+ 89: 99, # 99KB of SMEM - 1KB reserved for the driver
1109
+ 90: 227, # 227KB of SMEM - 1KB reserved for the driver
1110
+ 100: 227, # 227KB of SMEM - 1KB reserved for the driver
1111
+ }
1112
+
1113
+ ###################################################################################################
1114
+
1115
+ #
1116
+ def SubstituteTemplate(template, values):
1117
+ text = template
1118
+ changed = True
1119
+ while changed:
1120
+ changed = False
1121
+ for key, value in values.items():
1122
+ regex = "\\$\\{%s\\}" % key
1123
+ newtext = re.sub(regex, value, text)
1124
+ if newtext != text:
1125
+ changed = True
1126
+ text = newtext
1127
+ return text
1128
+
1129
+ ###################################################################################################
1130
+
1131
+ #
1132
+ class GemmKind(enum.Enum):
1133
+ Gemm = enum_auto()
1134
+ Sparse = enum_auto()
1135
+ Universal = enum_auto()
1136
+ Universal3x = enum_auto()
1137
+ SparseUniversal3x = enum_auto()
1138
+ PlanarComplex = enum_auto()
1139
+ PlanarComplexArray = enum_auto()
1140
+ Grouped = enum_auto()
1141
+ BlockScaledUniversal3x = enum_auto()
1142
+ GroupedUniversal3x = enum_auto()
1143
+ GroupedBlockScaledUniversal3x = enum_auto()
1144
+ BlockwiseUniversal3x = enum_auto()
1145
+ GroupedBlockwiseUniversal3x = enum_auto()
1146
+
1147
+ #
1148
+ GemmKindNames = {
1149
+ GemmKind.Gemm: "gemm",
1150
+ GemmKind.Sparse: "spgemm",
1151
+ GemmKind.Universal: "gemm",
1152
+ GemmKind.Universal3x: "gemm",
1153
+ GemmKind.SparseUniversal3x: "spgemm",
1154
+ GemmKind.PlanarComplex: "gemm_planar_complex",
1155
+ GemmKind.PlanarComplexArray: "gemm_planar_complex_array",
1156
+ GemmKind.Grouped: "gemm_grouped",
1157
+ GemmKind.BlockScaledUniversal3x: "gemm",
1158
+ GemmKind.GroupedUniversal3x: "gemm_grouped",
1159
+ GemmKind.GroupedBlockScaledUniversal3x: "gemm_grouped",
1160
+ GemmKind.BlockwiseUniversal3x: "gemm",
1161
+ GemmKind.GroupedBlockwiseUniversal3x: "gemm_grouped"
1162
+ }
1163
+
1164
+ #
1165
+ class RankKKind(enum.Enum):
1166
+ Universal = enum_auto()
1167
+
1168
+ #
1169
+ RankKKindNames = {
1170
+ RankKKind.Universal: "rank_k"
1171
+ }
1172
+
1173
+ #
1174
+ class TrmmKind(enum.Enum):
1175
+ Universal = enum_auto()
1176
+
1177
+ #
1178
+ TrmmKindNames = {
1179
+ TrmmKind.Universal: "trmm"
1180
+ }
1181
+
1182
+ #
1183
+ class SymmKind(enum.Enum):
1184
+ Universal = enum_auto()
1185
+
1186
+ #
1187
+ SymmKindNames = {
1188
+ SymmKind.Universal: "symm"
1189
+ }
1190
+
1191
+ #
1192
+ class EpilogueFunctor(enum.Enum):
1193
+ LinearCombination = enum_auto()
1194
+ LinearCombinationClamp = enum_auto()
1195
+
1196
+ #
1197
+ EpilogueFunctorTag = {
1198
+ EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination',
1199
+ EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp',
1200
+ }
1201
+
1202
+ #
1203
+ class MixedInputMode(enum.Enum):
1204
+ ConvertOnly = enum_auto()
1205
+ ScaleOnly = enum_auto()
1206
+ ScaleWithZeroPoint = enum_auto()
1207
+
1208
+ #
1209
+ class SwizzlingFunctor(enum.Enum):
1210
+ Identity1 = enum_auto()
1211
+ Identity2 = enum_auto()
1212
+ Identity4 = enum_auto()
1213
+ Identity8 = enum_auto()
1214
+ Horizontal = enum_auto()
1215
+ StridedDgradIdentity1 = enum_auto()
1216
+ StridedDgradIdentity4 = enum_auto()
1217
+ StridedDgradHorizontal = enum_auto()
1218
+ StreamK = enum_auto()
1219
+
1220
+ #
1221
+ SwizzlingFunctorTag = {
1222
+ SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>',
1223
+ SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>',
1224
+ SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>',
1225
+ SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>',
1226
+ SwizzlingFunctor.Horizontal: 'cutlass::gemm::threadblock::GemmHorizontalThreadblockSwizzle',
1227
+ SwizzlingFunctor.StridedDgradIdentity1: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>',
1228
+ SwizzlingFunctor.StridedDgradIdentity4: 'cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>',
1229
+ SwizzlingFunctor.StridedDgradHorizontal: 'cutlass::conv::threadblock::StridedDgradHorizontalThreadblockSwizzle',
1230
+ SwizzlingFunctor.StreamK: 'cutlass::gemm::threadblock::ThreadblockSwizzleStreamK',
1231
+ }
1232
+
1233
+ #
1234
+ class GroupScheduleMode(enum.Enum):
1235
+ Device = enum_auto(),
1236
+ Host = enum_auto()
1237
+
1238
+ #
1239
+ GroupScheduleModeTag = {
1240
+ GroupScheduleMode.Device: 'cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly',
1241
+ GroupScheduleMode.Host: 'cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute'
1242
+ }
1243
+
1244
+ #
1245
+ ShortGroupScheduleModeNames = {
1246
+ GroupScheduleMode.Device: 'Device',
1247
+ GroupScheduleMode.Host: 'Host'
1248
+ }
1249
+
1250
+ ###################################################################################################
1251
+
1252
+ #
1253
+ class ConvKind(enum.IntEnum):
1254
+ Fprop = 0
1255
+ Dgrad = 1
1256
+ Wgrad = 2
1257
+
1258
+ #
1259
+ ConvKindTag = {
1260
+ ConvKind.Fprop: 'cutlass::conv::Operator::kFprop',
1261
+ ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad',
1262
+ ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad'
1263
+ }
1264
+
1265
+ ConvKindNames = {
1266
+ ConvKind.Fprop: 'fprop',
1267
+ ConvKind.Dgrad: 'dgrad',
1268
+ ConvKind.Wgrad: 'wgrad',
1269
+ }
1270
+
1271
+ class ConvMode(enum.IntEnum):
1272
+ CrossCorrelation = 0
1273
+ Convolution = 1
1274
+
1275
+ #
1276
+ class IteratorAlgorithm(enum.Enum):
1277
+ Analytic = 0
1278
+ Optimized = 1
1279
+ FixedChannels = 2
1280
+ FewChannels = 3
1281
+ FixedStrideDilation = 4
1282
+
1283
+ #
1284
+ IteratorAlgorithmTag = {
1285
+ IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic',
1286
+ IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized',
1287
+ IteratorAlgorithm.FixedChannels: 'cutlass::conv::IteratorAlgorithm::kFixedChannels',
1288
+ IteratorAlgorithm.FewChannels: 'cutlass::conv::IteratorAlgorithm::kFewChannels',
1289
+ IteratorAlgorithm.FixedStrideDilation: 'cutlass::conv::IteratorAlgorithm::kFixedStrideDilation'
1290
+ }
1291
+
1292
+ IteratorAlgorithmNames = {
1293
+ IteratorAlgorithm.Analytic: 'analytic',
1294
+ IteratorAlgorithm.Optimized: 'optimized',
1295
+ IteratorAlgorithm.FixedChannels: 'fixed_channels',
1296
+ IteratorAlgorithm.FewChannels: 'few_channels',
1297
+ IteratorAlgorithm.FixedStrideDilation: 'fixed_stride_dilation'
1298
+ }
1299
+
1300
+ #
1301
+ class StrideSupport(enum.Enum):
1302
+ Strided = 0
1303
+ Unity = 1
1304
+ Fixed = 2
1305
+
1306
+ #
1307
+ StrideSupportTag = {
1308
+ StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided',
1309
+ StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity',
1310
+ StrideSupport.Fixed: 'cutlass::conv::StrideSupport::kFixed'
1311
+ }
1312
+
1313
+ StrideSupportNames = {
1314
+ StrideSupport.Strided: '',
1315
+ StrideSupport.Unity: 'unity_stride',
1316
+ StrideSupport.Fixed: 'fixed_stride'
1317
+ }
1318
+
1319
+ #
1320
+ class GroupMode(enum.Enum):
1321
+ NoneGroup = enum_auto() # dense conv (G=1)
1322
+ SingleGroup = enum_auto() # grouped convolution (single group per CTA)
1323
+ MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA)
1324
+ Depthwise = enum_auto() # Depthwise convolution ( C=K=G )
1325
+
1326
+ #
1327
+ GroupModeTag = {
1328
+ GroupMode.NoneGroup: 'cutlass::conv::GroupMode::kNone',
1329
+ GroupMode.SingleGroup: 'cutlass::conv::GroupMode::kSingleGroup',
1330
+ GroupMode.MultipleGroup: 'cutlass::conv::GroupMode::kMultipleGroup',
1331
+ GroupMode.Depthwise: 'cutlass::conv::GroupMode::kDepthwise',
1332
+ }
1333
+
1334
+ GroupModeNames = {
1335
+ GroupMode.NoneGroup: '',
1336
+ GroupMode.SingleGroup: 'single_group',
1337
+ GroupMode.MultipleGroup: 'multiple_group',
1338
+ GroupMode.Depthwise: 'depthwise',
1339
+ }
1340
+
1341
+ DynamicClusterShape = [0, 0, 1]
1342
+
1343
+ ###################################################################################################
1344
+
1345
+ #
1346
+ class MathInstruction:
1347
+ def __init__(self,
1348
+ instruction_shape, \
1349
+ element_a, element_b, element_accumulator, \
1350
+ opcode_class, math_operation = MathOperation.multiply_add \
1351
+ , element_scale_factor = None
1352
+ ):
1353
+
1354
+ self.instruction_shape = instruction_shape
1355
+ self.element_a = element_a
1356
+ self.element_b = element_b
1357
+ self.element_accumulator = element_accumulator
1358
+ self.opcode_class = opcode_class
1359
+ self.math_operation = math_operation
1360
+ self.element_scale_factor = element_scale_factor
1361
+
1362
+ #
1363
+ class TileDescription:
1364
+
1365
+ def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute, cluster_shape = [1,1,1], explicit_vector_sizes = None):
1366
+ self.threadblock_shape = threadblock_shape
1367
+ self.tile_shape = threadblock_shape
1368
+ self.stages = stages
1369
+ self.warp_count = warp_count
1370
+ self.math_instruction = math_instruction
1371
+ self.minimum_compute_capability = min_compute
1372
+ self.maximum_compute_capability = max_compute
1373
+ self.cluster_shape = cluster_shape
1374
+ self.explicit_vector_sizes = explicit_vector_sizes
1375
+
1376
+ def procedural_name(self):
1377
+ if self.minimum_compute_capability >= 90:
1378
+ return "{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{s}".format(
1379
+ tbm = self.threadblock_shape[0],
1380
+ tbn = self.threadblock_shape[1],
1381
+ tbk = self.threadblock_shape[2],
1382
+ cm = self.cluster_shape[0],
1383
+ cn = self.cluster_shape[1],
1384
+ ck = self.cluster_shape[2],
1385
+ s = self.stages)
1386
+ else:
1387
+ return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
1388
+
1389
+ #
1390
+ class Direct2dConvFixedStrideDilationTileDescription:
1391
+ def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
1392
+ self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
1393
+ self.threadblock_output_shape = threadblock_output_shape
1394
+ self.filter_shape = filter_shape
1395
+ self.stages = stages
1396
+ self.warp_count = warp_count
1397
+ self.stride = stride
1398
+ self.dilation = dilation
1399
+ self.math_instruction = math_instruction
1400
+ self.minimum_compute_capability = min_compute
1401
+ self.maximum_compute_capability = max_compute
1402
+
1403
+ def procedural_name(self):
1404
+ str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
1405
+ self.threadblock_shape[1],
1406
+ self.threadblock_shape[2],
1407
+ self.threadblock_output_shape[0],
1408
+ self.threadblock_output_shape[1],
1409
+ self.threadblock_output_shape[2],
1410
+ self.threadblock_output_shape[3],
1411
+ self.stages,
1412
+ self.filter_shape[0],
1413
+ self.filter_shape[1])
1414
+ # Fixed Strided and dilation
1415
+ if self.stride != [-1, -1] and self.dilation != [-1, -1]:
1416
+ str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
1417
+ self.stride[1],
1418
+ self.dilation[0],
1419
+ self.dilation[1])
1420
+ return str_name
1421
+
1422
+ #
1423
+ class Direct2dConvFixedStrideDilationTileDescription:
1424
+ def __init__(self, threadblock_output_shape, filter_shape, stages, stride, dilation, warp_count, math_instruction, min_compute, max_compute):
1425
+ self.threadblock_shape = [threadblock_output_shape[0]*threadblock_output_shape[1]*threadblock_output_shape[2], threadblock_output_shape[3], filter_shape[0]*filter_shape[1]]
1426
+ self.threadblock_output_shape = threadblock_output_shape
1427
+ self.filter_shape = filter_shape
1428
+ self.stages = stages
1429
+ self.warp_count = warp_count
1430
+ self.stride = stride
1431
+ self.dilation = dilation
1432
+ self.math_instruction = math_instruction
1433
+ self.minimum_compute_capability = min_compute
1434
+ self.maximum_compute_capability = max_compute
1435
+
1436
+ def procedural_name(self):
1437
+ str_name = "%dx%dx%d_%dx%dx%dx%d_%d_filter%dx%d" % (self.threadblock_shape[0],
1438
+ self.threadblock_shape[1],
1439
+ self.threadblock_shape[2],
1440
+ self.threadblock_output_shape[0],
1441
+ self.threadblock_output_shape[1],
1442
+ self.threadblock_output_shape[2],
1443
+ self.threadblock_output_shape[3],
1444
+ self.stages,
1445
+ self.filter_shape[0],
1446
+ self.filter_shape[1])
1447
+ # Fixed Strided and dilation
1448
+ if self.stride != [-1, -1] and self.dilation != [-1, -1]:
1449
+ str_name += "_stride%dx%d_dilation%dx%d" % (self.stride[0],
1450
+ self.stride[1],
1451
+ self.dilation[0],
1452
+ self.dilation[1])
1453
+ return str_name
1454
+
1455
+ #
1456
+ class TensorDescription:
1457
+ def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none):
1458
+ self.element = element
1459
+ self.layout = layout
1460
+ self.alignment = alignment
1461
+ self.complex_transform = complex_transform
1462
+
1463
+ #
1464
+ class SymmetricTensorDescription:
1465
+ def __init__(self, element, layout, fill_mode, alignment = 1, complex_transform = ComplexTransform.none, side_mode = SideMode.Left):
1466
+ self.element = element
1467
+ self.layout = layout
1468
+ self.fill_mode = fill_mode
1469
+ self.alignment = alignment
1470
+ self.complex_transform = complex_transform
1471
+ self.side_mode = side_mode
1472
+
1473
+ #
1474
+ class TriangularTensorDescription:
1475
+ def __init__(self, element, layout, side_mode, fill_mode, diag_type, alignment = 1, complex_transform = ComplexTransform.none):
1476
+ self.element = element
1477
+ self.layout = layout
1478
+ self.side_mode = side_mode
1479
+ self.fill_mode = fill_mode
1480
+ self.diag_type = diag_type
1481
+ self.alignment = alignment
1482
+ self.complex_transform = complex_transform
1483
+
1484
+ #
1485
+ def CalculateSmemUsage(operation):
1486
+ cta_shape = operation.tile_description.threadblock_shape
1487
+ stages = operation.tile_description.stages
1488
+
1489
+ if operation.operation_kind == OperationKind.Gemm and operation.gemm_kind == GemmKind.Sparse:
1490
+ # Elements represented by 8 bits of metadata (based on 4:8, 2:4 or 1:2 sparsity)
1491
+ if DataTypeSize[operation.A.element] == 32:
1492
+ elements_per_8b_md = 2
1493
+ elif DataTypeSize[operation.A.element] == 4:
1494
+ elements_per_8b_md = 8
1495
+ else:
1496
+ elements_per_8b_md = 4
1497
+
1498
+ smem_per_stage = DataTypeSize[operation.A.element] * cta_shape[0] * (cta_shape[2] // 2) // 8 + \
1499
+ DataTypeSize[operation.B.element] * cta_shape[1] * cta_shape[2] // 8 + \
1500
+ cta_shape[0] * (cta_shape[2] // 2) // elements_per_8b_md
1501
+ else:
1502
+ # Few BLAS3 operations only have A tensor
1503
+ data_type_size_a = DataTypeSize[operation.A.element]
1504
+ data_type_size_b = DataTypeSize[operation.A.element]
1505
+ if operation.is_mixed_input():
1506
+ data_type_size_b = DataTypeSize[operation.B.element]
1507
+
1508
+ smem_per_stage = data_type_size_a * cta_shape[0] * cta_shape[2] // 8 + \
1509
+ data_type_size_b * cta_shape[1] * cta_shape[2] // 8
1510
+
1511
+ smem_usage = smem_per_stage * stages
1512
+ return (smem_usage >> 10)
1513
+
1514
+
1515
+ class GemmUniversalMode(enum.IntEnum):
1516
+ """
1517
+ Types corresponding to GemmUniversalMode
1518
+ """
1519
+ Gemm = 0
1520
+ GemmSplitKParallel = 1
1521
+ Batched = 2
1522
+ Array = 3
1523
+
1524
+
1525
+ class SplitKMode(enum.IntEnum):
1526
+ """
1527
+ Types corresponding to SplitKMode
1528
+ """
1529
+ NoneSplitK = 0
1530
+ Serial = 1
1531
+ Parallel = 2
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/manifest.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for filtering CUTLASS library kernels and emitting library intitialization
35
+ and building code
36
+ """
37
+
38
+ import enum
39
+ import logging
40
+ import os.path
41
+ import shutil
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ from cutlass_library.gemm_operation import *
49
+ from cutlass_library.rank_k_operation import *
50
+ from cutlass_library.rank_2k_operation import *
51
+ from cutlass_library.trmm_operation import *
52
+ from cutlass_library.symm_operation import *
53
+ from cutlass_library.conv2d_operation import *
54
+ from cutlass_library.conv3d_operation import *
55
+ except ImportError:
56
+ from library import *
57
+ from gemm_operation import *
58
+ from rank_k_operation import *
59
+ from rank_2k_operation import *
60
+ from trmm_operation import *
61
+ from symm_operation import *
62
+ from conv2d_operation import *
63
+ from conv3d_operation import *
64
+
65
+ ###################################################################################################
66
+ _LOGGER = logging.getLogger(__name__)
67
+
68
+
69
+ class EmitOperationKindAll:
70
+ """
71
+ Emit the OperationKind-level CUTLASS library initialization code.
72
+ The code is generated in the {generated_path}/{operation_kind} directory
73
+ (e.g., tools/library/generated/gemm in the build directory,
74
+ for OperationKind=Gemm), in the all_{operation_kind}_operations.cu file
75
+ (e.g., all_gemm_operations.cu for OperationKind=Gemm).
76
+ That file declares several functions in namespace cutlass::library.
77
+ The functions all have this form,
78
+
79
+ void initialize_{configuration_name}(Manifest& manifest);
80
+
81
+ The file also _defines_ the following function in that namespace.
82
+
83
+ void initialize_all_{operation_kind}_operations(Manifest& manifest);
84
+
85
+ That function calls all of the functions declared in this file.
86
+ Those functions are defined in subdirectories
87
+ (which this class does not create).
88
+ """
89
+
90
+ def __init__(self, generated_path, kind, args):
91
+ self.generated_path = generated_path
92
+ self.kind = kind
93
+ self.args = args
94
+
95
+ self.header_template ="""
96
+ /*
97
+ Generated by manifest.py - Do not edit.
98
+ */
99
+
100
+ #include "cutlass/cutlass.h"
101
+ #include "cutlass/library/library.h"
102
+ #include "cutlass/library/manifest.h"
103
+
104
+ namespace cutlass {
105
+ namespace library {
106
+
107
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
108
+
109
+ """
110
+
111
+ self.entry_template = """
112
+
113
+ //
114
+ // Entry point to construct operations
115
+ //
116
+ void initialize_all_${operation_name}_operations(Manifest &manifest) {
117
+ """
118
+ self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
119
+ self.configuration_template =" initialize_${configuration_name}(manifest);\n"
120
+
121
+ self.epilogue_template ="""}
122
+
123
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
124
+
125
+ } // namespace library
126
+ } // namespace cutlass
127
+
128
+ """
129
+
130
+ #
131
+ def __enter__(self):
132
+ _LOGGER.debug("*** EmitOperationKindAll::__enter__")
133
+
134
+ self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind])
135
+ _LOGGER.debug('*** operation_path (directory to create): ' +
136
+ str(self.operation_path));
137
+ os.makedirs(self.operation_path, exist_ok=True)
138
+
139
+ self.top_level_path = os.path.join(self.operation_path, f"all_{OperationKindNames[self.kind]}_operations.cu")
140
+ _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}")
141
+
142
+ self.top_level_file = open(self.top_level_path, "w")
143
+ self.top_level_file.write(self.header_template)
144
+
145
+ self.source_files = [self.top_level_path,]
146
+
147
+ self.configurations = []
148
+
149
+ return self
150
+
151
+ #
152
+ def emit(self, operations):
153
+ _LOGGER.debug('*** EmitOperationKindAll::emit')
154
+ _LOGGER.debug(f"*** len(operations): {len(operations)}")
155
+ _LOGGER.debug(f"*** min_cc list: {sorted(min_cc for min_cc, _ in operations.items())}")
156
+
157
+ for min_cc, configurations in sorted(operations.items()):
158
+ _LOGGER.debug(f"*** min_cc={min_cc}")
159
+
160
+ for configuration_name, _ in configurations.items():
161
+ _LOGGER.debug(f"*** configuration_name={configuration_name}")
162
+ self.configurations.append(configuration_name)
163
+ self.top_level_file.write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} ))
164
+
165
+ #
166
+ def __exit__(self, exception_type, exception_value, traceback):
167
+ _LOGGER.debug("*** EmitOperationKindAll::__exit__")
168
+
169
+ self.top_level_file.write(SubstituteTemplate(self.entry_template, {'operation_name': OperationKindNames[self.kind]}))
170
+
171
+ for configuration_name in self.configurations:
172
+ self.top_level_file.write(SubstituteTemplate(self.configuration_template, {'configuration_name': configuration_name}))
173
+
174
+ self.top_level_file.write(self.epilogue_template)
175
+ self.top_level_file.close()
176
+
177
+
178
+ class EmitOperationKindLibrary:
179
+ """
180
+ Emit the CUTLASS library initialization code for each OperationKind.
181
+ The code is generated in the directory
182
+ {generated_path}/{operation_kind}/{min_cc}
183
+ (e.g., tools/library/generated/gemm/90 in the build directory,
184
+ for min_cc=90 and OperationKind=Gemm), in the file
185
+ all_sm{min_cc}_{operation_kind}_operations.cu
186
+ (e.g., all_sm90_gemm_operations.cu for min_cc=90 and OperationKind=Gemm).
187
+ The min_cc variable here indicates the minimum GPU architecture version
188
+ that the things to be initialized require.
189
+ For example, min_cc=90 indicates sm90.
190
+
191
+ That file declares several functions in namespace cutlass::library.
192
+ The functions all have this form,
193
+
194
+ void initialize_all_sm{min_cc}_{subclass_name}_{extended_name}_operations(Manifest& manifest);
195
+
196
+ where extended_name is operation.extended_name() for all the operations
197
+ given to the emit method (which see below). (All operations for a given
198
+ configuration_name are guaranteed to have the same extended_name().)
199
+
200
+ The file also _defines_ the following function in that namespace.
201
+
202
+ void initialize_all_sm{min_cc}__{operation_kind}_operations(Manifest& manifest);
203
+
204
+ That function calls all of the functions declared in this file.
205
+ Those functions are defined in subdirectories.
206
+ The mapping from OperationKind to emitter handles the details
207
+ of what happens in each of those subdirectories.
208
+ """
209
+
210
+ def __init__(self, generated_path, min_cc, kind, args):
211
+ self.generated_path = generated_path
212
+ self.min_cc = min_cc
213
+ self.kind = kind
214
+ self.args = args
215
+ self.emitters = {
216
+ OperationKind.Gemm: EmitGemmConfigurationLibrary,
217
+ OperationKind.Conv2d: EmitConv2dConfigurationLibrary,
218
+ OperationKind.Conv3d: EmitConv3dConfigurationLibrary,
219
+ OperationKind.RankK: EmitRankKConfigurationLibrary,
220
+ OperationKind.Rank2K: EmitRank2KConfigurationLibrary,
221
+ OperationKind.Trmm: EmitTrmmConfigurationLibrary,
222
+ OperationKind.Symm: EmitSymmConfigurationLibrary
223
+ }
224
+
225
+ self.header_template ="""
226
+ /*
227
+ Generated by manifest.py - Do not edit.
228
+ */
229
+
230
+ #include "cutlass/cutlass.h"
231
+ #include "cutlass/library/library.h"
232
+ #include "cutlass/library/manifest.h"
233
+
234
+ namespace cutlass {
235
+ namespace library {
236
+
237
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
238
+
239
+ """
240
+ self.entry_template = """
241
+
242
+ //
243
+ // Entry point to construct operations
244
+ //
245
+ void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest) {
246
+ """
247
+ self.configuration_prototype_template = "void initialize_${configuration_name}(Manifest &manifest);\n"
248
+ self.configuration_template = " initialize_${configuration_name}(manifest);\n"
249
+ self.subclass_call_template = " initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(manifest);\n"
250
+ self.subclass_prototype_template = "void initialize_all_sm${min_cc}_${subclass_name}_${operation_name}_operations(Manifest &manifest);\n"
251
+ self.epilogue_template ="""}
252
+
253
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
254
+
255
+ } // namespace library
256
+ } // namespace cutlass
257
+
258
+ """
259
+
260
+ #
261
+ def __enter__(self):
262
+ _LOGGER.debug("*** EmitOperationKindLibrary::__enter__")
263
+ _LOGGER.debug(f"*** generated_path: {str(self.generated_path)}")
264
+ _LOGGER.debug(f"*** OperationKindNames[kind]: {OperationKindNames[self.kind]}")
265
+ _LOGGER.debug(f"*** min_cc: {self.min_cc}")
266
+
267
+ self.operation_path = os.path.join(self.generated_path, OperationKindNames[self.kind], str(self.min_cc))
268
+ _LOGGER.debug(f"*** operation_path (directory to make): {str(self.operation_path)}")
269
+ os.makedirs(self.operation_path)
270
+
271
+ self.top_level_path = os.path.join(self.operation_path, f"all_sm{self.min_cc}_{OperationKindNames[self.kind]}_operations.cu")
272
+ _LOGGER.debug(f"*** top_level_path (file to write): {str(self.top_level_path)}")
273
+
274
+ self.top_level_file = open(self.top_level_path, "w")
275
+ self.top_level_file.write(self.header_template)
276
+
277
+ self.source_files = {}
278
+
279
+ # Each {operation_kind x cc} combination is further decomposed by the instruction
280
+ # types used. This dictionary used to track the file handles for the top-level
281
+ # files of each subclass
282
+ self.subclass_files = {}
283
+
284
+ # Configurations in each sub class
285
+ self.subclass_configurations = {}
286
+
287
+ return self
288
+
289
+ #
290
+ def emit(self, configuration_name, operations):
291
+ _LOGGER.debug("*** EmitOperationKindLibrary::emit")
292
+ _LOGGER.debug(f"*** configuration_name: {configuration_name}")
293
+
294
+ assert len(operations) > 0
295
+
296
+ # The extended name for all operations of a given configuration_name is guaranteed
297
+ # to be the same because extended_name() is used in defining configuration_name. Thus,
298
+ # we can safely use the extended_name() of the first operation.
299
+ extended_name = operations[0].extended_name()
300
+ _LOGGER.debug('*** extended_name (for all ops): ' + extended_name)
301
+
302
+ # Create a directory for operations with this subclass if it does not exist
303
+ if extended_name not in self.subclass_files:
304
+ subclass_path = os.path.join(self.operation_path, extended_name)
305
+ _LOGGER.debug(f"*** subclass_path: {str(subclass_path)}")
306
+ os.mkdir(subclass_path)
307
+
308
+ self.subclass_configurations[extended_name] = []
309
+
310
+ # Open a new top-level file for this sub class
311
+ subclass_top_level_path = os.path.join(
312
+ subclass_path, f"all_sm{self.min_cc}_{extended_name}_{OperationKindNames[self.kind]}_operations.cu")
313
+ _LOGGER.debug('*** subclass_top_level_path (min_cc, extended_name, ' +
314
+ 'OperationKind): ' + str(subclass_top_level_path))
315
+
316
+ self.subclass_files[extended_name] = open(subclass_top_level_path, "w")
317
+ self.subclass_files[extended_name].write(self.header_template)
318
+
319
+ self.source_files[extended_name] = [subclass_top_level_path]
320
+
321
+ subclass_dir = os.path.dirname(self.subclass_files[extended_name].name)
322
+ _LOGGER.debug('*** subclass_dir: ' + str(subclass_dir))
323
+
324
+ with self.emitters[self.kind](subclass_dir, configuration_name) as configuration_emitter:
325
+ for operation in operations:
326
+ configuration_emitter.emit(operation)
327
+
328
+ _LOGGER.debug('*** configuration_emitter.configuration_path: ' +
329
+ str(configuration_emitter.configuration_path))
330
+ self.source_files[extended_name].append(configuration_emitter.configuration_path)
331
+
332
+ self.subclass_configurations[extended_name].append(configuration_name)
333
+ self.subclass_files[extended_name].write(SubstituteTemplate(self.configuration_prototype_template, {'configuration_name': configuration_name} ))
334
+
335
+ #
336
+ def __exit__(self, exception_type, exception_value, traceback):
337
+ _LOGGER.debug("*** EmitOperationKindLibrary::__exit__")
338
+ for subclass_name, subclass_file in sorted(self.subclass_files.items()):
339
+ subclass_cfg = {
340
+ 'min_cc': str(self.min_cc),
341
+ 'subclass_name': subclass_name,
342
+ 'operation_name': OperationKindNames[self.kind]
343
+ }
344
+ self.top_level_file.write(SubstituteTemplate(self.subclass_prototype_template, subclass_cfg))
345
+
346
+ self.top_level_file.write(
347
+ SubstituteTemplate(self.entry_template, {
348
+ 'min_cc': str(self.min_cc),
349
+ 'subclass_name': '',
350
+ 'operation_name': OperationKindNames[self.kind]
351
+ }))
352
+
353
+ # Finish and close all subclass files
354
+ for subclass_name, subclass_file in sorted(self.subclass_files.items()):
355
+ subclass_cfg = {
356
+ 'min_cc': str(self.min_cc),
357
+ 'subclass_name': subclass_name,
358
+ 'operation_name': OperationKindNames[self.kind]
359
+ }
360
+ subclass_file.write(SubstituteTemplate(self.entry_template, subclass_cfg))
361
+
362
+ for configuration in self.subclass_configurations[subclass_name]:
363
+ subclass_file.write(
364
+ SubstituteTemplate(self.configuration_template, {
365
+ 'configuration_name': configuration
366
+ }))
367
+
368
+ subclass_file.write(self.epilogue_template)
369
+ subclass_file.close()
370
+
371
+ # Write the call to initialize_all for this subclass to the top-level file
372
+ self.top_level_file.write(SubstituteTemplate(self.subclass_call_template, subclass_cfg))
373
+
374
+ self.top_level_file.write(self.epilogue_template)
375
+ self.top_level_file.close()
376
+
377
+ class EmitInterfaceLibrary:
378
+ """
379
+ Emit the topmost-level CUTLASS library initialization code.
380
+ The code is generated in the generated_path directory
381
+ (e.g., tools/library/generated in the build directory),
382
+ in the initialize_all.cpp file.
383
+ That file declares several functions in namespace cutlass::library.
384
+ The functions all have this form,
385
+
386
+ void initialize_all_{operation_kind}_operations(Manifest& manifest);
387
+
388
+ where {operation_kind} abbreviates the "kind" of operation
389
+ (e.g., gemm for matrix-matrix multiply, conv2d for 2-d convolution,
390
+ or trmm for triangular solve with multiple right-hand sides).
391
+ The definitions of these functions live in subdirectories.
392
+
393
+ The file also _defines_ the following function in that namespace.
394
+
395
+ void initialize_all(Manifest& manifest);
396
+
397
+ That function first prepares the manifest, and then
398
+ calls all of the functions declared in this file.
399
+ """
400
+
401
+ def __init__(self, generated_path, operation_count, args):
402
+ self.generated_path = generated_path
403
+ self.args = args
404
+
405
+ self.prototypes = []
406
+ self.fn_calls = []
407
+ self.operation_count = str(operation_count)
408
+
409
+ self.top_level_hdr_template = '''
410
+ /*
411
+ Generated by manifest.py - Do not edit.
412
+ */
413
+ '''
414
+ self.top_level_prologue = '''
415
+
416
+ #include "cutlass/library/library.h"
417
+ #include "cutlass/library/manifest.h"
418
+
419
+ namespace cutlass {
420
+ \tnamespace library {
421
+
422
+ ${prototypes}
423
+ '''
424
+
425
+ self.top_level_initialize_kind = '''
426
+ \t\tvoid initialize_all_${kind}_operations(Manifest &manifest) {
427
+ ${fn_calls}
428
+ \t\t}
429
+ '''
430
+
431
+ self.top_level_initialize = '''
432
+ \t\tvoid initialize_all(Manifest &manifest) {
433
+ \t\t\tmanifest.reserve(${operation_count});\n
434
+ ${fn_calls}
435
+ \t\t}
436
+ '''
437
+
438
+ self.top_level_suffix = '''
439
+ \t} // namespace library
440
+ } // namespace cutlass
441
+
442
+ '''
443
+
444
+ #
445
+ def __enter__(self):
446
+ _LOGGER.debug("*** EmitInterfaceLibrary::__enter__")
447
+
448
+ self.top_level_path = os.path.join(self.generated_path, 'initialize_all.cpp')
449
+ _LOGGER.debug("*** top_level_path: " + str(self.top_level_path))
450
+
451
+ self.top_level_file = open(self.top_level_path, "w")
452
+ self.top_level_file.write(self.top_level_hdr_template)
453
+
454
+ self.source_files = [self.top_level_path,]
455
+
456
+ return self
457
+
458
+ #
459
+ def emit(self, operation_name):
460
+ _LOGGER.debug("*** EmitInterfaceLibrary::emit")
461
+ _LOGGER.debug("*** operation_name: " + operation_name)
462
+
463
+ self.prototypes.append(SubstituteTemplate(
464
+ "\t\tvoid initialize_all_${operation_kind}_operations(Manifest &manifest);",
465
+ {'operation_kind': operation_name}))
466
+
467
+ self.fn_calls.append(SubstituteTemplate(
468
+ "\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
469
+ {'operation_kind': operation_name}))
470
+
471
+ #
472
+ def __exit__(self, exception_type, exception_value, traceback):
473
+ _LOGGER.debug("*** EmitInterfaceLibrary::__exit__")
474
+
475
+ self.top_level_file.write(SubstituteTemplate(self.top_level_prologue, {'prototypes':"\n".join(self.prototypes)}))
476
+
477
+ # Write out initialize_all method
478
+ self.top_level_file.write(SubstituteTemplate(self.top_level_initialize,
479
+ {'operation_count': self.operation_count, 'fn_calls':"\n".join(self.fn_calls)}))
480
+
481
+ self.top_level_file.write(self.top_level_suffix)
482
+ self.top_level_file.close()
483
+
484
+ ###################################################################################################
485
+ ###################################################################################################
486
+
487
+ class Options:
488
+ def __init__(self):
489
+ pass
490
+
491
+ ###################################################################################################
492
+
493
+ #
494
+ class Manifest:
495
+
496
+ #
497
+ def __init__(self, args = None):
498
+ self.operations = {}
499
+ self.args = args
500
+ self.operation_count = 0
501
+ self.operations_by_name = {}
502
+
503
+ self.kernel_filter = ''
504
+ self.kernel_filter_list = []
505
+ self.kernel_names = []
506
+ self.operations_enabled = []
507
+ self.selected_kernels = []
508
+ self.ignore_kernel_names = []
509
+ self.exclude_kernel_names = []
510
+ self.compute_capabilities_baseline = [50,]
511
+ self.compute_capabilities_feature_set = ['50',]
512
+ self.curr_build_dir = '.'
513
+ self.filter_by_cc = True
514
+
515
+ if self.args:
516
+ self.kernel_filter = self.args.kernels
517
+ self.curr_build_dir = args.curr_build_dir
518
+
519
+ # A common user error is to use commas instead of semicolons.
520
+ if ',' in args.architectures:
521
+ raise RuntimeError("The list of architectures (CMake option CUTLASS_NVCC_ARCHS) must be semicolon-delimited.\nDon't use commas to separate the architectures; use semicolons.\nYou specified the list as: " + args.architectures)
522
+
523
+ self.compute_capabilities_feature_set = args.architectures.split(';') if len(args.architectures) else ['50',]
524
+ self.compute_capabilities_baseline = sorted(set(int(arch.split('a')[0].split('f')[0]) for arch in self.compute_capabilities_feature_set))
525
+
526
+ if args.filter_by_cc in ['false', 'False', '0']:
527
+ self.filter_by_cc = False
528
+
529
+ if args.operations == 'all':
530
+ self.operations_enabled = []
531
+ else:
532
+ operations_list = [
533
+ OperationKind.Gemm
534
+ , OperationKind.Conv2d
535
+ , OperationKind.Conv3d
536
+ , OperationKind.RankK
537
+ , OperationKind.Trmm
538
+ , OperationKind.Symm
539
+ ]
540
+ self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]
541
+
542
+ if args.kernels == 'all':
543
+ self.kernel_names = []
544
+ else:
545
+ self.kernel_names = [x for x in args.kernels.split(',') if x != '']
546
+
547
+ self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != '']
548
+ self.exclude_kernel_names = [x for x in args.exclude_kernels.split(',') if x != '']
549
+
550
+ if args.kernel_filter_file is None:
551
+ self.kernel_filter_list = []
552
+ else:
553
+ self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file)
554
+ _LOGGER.debug("Using {filter_count} kernel filters from {filter_file}".format(
555
+ filter_count = len(self.kernel_filter_list),
556
+ filter_file = args.kernel_filter_file))
557
+
558
+ self.operation_count = 0
559
+ self.operations_by_name = {}
560
+ self.disable_full_archs_compilation = args.disable_full_archs_compilation
561
+ self.is_kernel_filter_set_to_all = args.instantiation_level == "max" and args.kernels != ''
562
+ self.instantiation_level = 0
563
+ try:
564
+ self.instantiation_level = int(args.instantiation_level)
565
+ except ValueError:
566
+ self.instantiation_level = 0
567
+
568
+ def add_kernel_filter(self, filter_str):
569
+ filter_re = re.compile(filter_str)
570
+
571
+ self.kernel_filter_list.append(filter_re)
572
+
573
+ def get_instantiation_level(self, pruned_level=0, default_level=111, exhaustive_level=9992):
574
+ # Non-negative integer which determines how many kernels are instantiated.
575
+ # 0 = 0000 generates the fewest kernels, 9999 generates all possible combinations.
576
+ # increasing first digit reduces schedule / mixed type pruning,
577
+ # increasing second digit generates more cluster sizes,
578
+ # increasing third digit generates more MMA multipliers,
579
+ # increasing fourth digit generates more instruction shapes.
580
+
581
+ if self.instantiation_level > 0:
582
+ return self.instantiation_level
583
+
584
+ elif self.is_kernel_filter_set_to_all:
585
+ return exhaustive_level
586
+
587
+ elif self.kernel_filter == '':
588
+ return pruned_level
589
+
590
+ else:
591
+ return default_level
592
+
593
+
594
+ def get_kernel_filters(self, kernelListFile):
595
+ if os.path.isfile(kernelListFile):
596
+ with open(kernelListFile, 'r') as fileReader:
597
+ lines = [line.rstrip() for line in fileReader if not line.startswith("#")]
598
+
599
+ lines = [re.compile(line) for line in lines if line]
600
+ return lines
601
+ else:
602
+ return []
603
+
604
+ #
605
+ def filter_out_kernels(self, kernel_name, kernel_filter_list):
606
+
607
+ for kernel_filter_re in kernel_filter_list:
608
+ if kernel_filter_re.search(kernel_name) is not None:
609
+ return True
610
+
611
+ return False
612
+
613
+
614
+ #
615
+ def _filter_string_matches(self, filter_string, haystack):
616
+ ''' Returns true if all substrings appear in the haystack in order'''
617
+ substrings = filter_string.split('*')
618
+ for sub in substrings:
619
+ idx = haystack.find(sub)
620
+ if idx < 0:
621
+ return False
622
+ haystack = haystack[idx + len(sub):]
623
+ return True
624
+
625
+ #
626
+ def filter(self, operation):
627
+ ''' Filtering operations based on various criteria'''
628
+
629
+ # filter based on compute capability
630
+ enabled = not (self.filter_by_cc)
631
+
632
+ for cc in self.compute_capabilities_baseline:
633
+
634
+ if cc >= operation.tile_description.minimum_compute_capability and \
635
+ cc <= operation.tile_description.maximum_compute_capability and \
636
+ (cc not in SharedMemPerCC or SharedMemPerCC[cc] >= CalculateSmemUsage(operation)):
637
+
638
+ enabled = True
639
+ break
640
+
641
+ if not enabled:
642
+ return False
643
+
644
+ if len(self.operations_enabled) and not operation.operation_kind in self.operations_enabled:
645
+ return False
646
+
647
+ name = operation.procedural_name()
648
+
649
+ # eliminate duplicates
650
+ if name in self.operations_by_name.keys():
651
+ return False
652
+
653
+ # Filter based on list of valid substrings
654
+ if len(self.kernel_names):
655
+ enabled = False
656
+
657
+ # compare against the include list
658
+ for name_substr in self.kernel_names:
659
+ if self._filter_string_matches(name_substr, name):
660
+ _LOGGER.debug(f"Kernel {name} included due to filter string '{name_substr}'.")
661
+ enabled = True
662
+ break
663
+ else:
664
+ _LOGGER.debug(f"Kernel {name} NOT included due to not matching '{name_substr}'.")
665
+
666
+ # compare against the exclude list
667
+ for name_substr in self.ignore_kernel_names:
668
+ if self._filter_string_matches(name_substr, name):
669
+ _LOGGER.debug(f"Kernel {name} ignored due to filter string '{name_substr}'.")
670
+ enabled = False
671
+ break
672
+ else:
673
+ _LOGGER.debug(f"Kernel {name} NOT ignored due to not matching '{name_substr}'.")
674
+
675
+ if len(self.kernel_filter_list) > 0:
676
+ if self.filter_out_kernels(name, self.kernel_filter_list):
677
+ _LOGGER.debug(f"Kernel {name} matched via kernel filter file.")
678
+ enabled = True
679
+ else:
680
+ _LOGGER.debug(f"Kernel {name} culled due to no match in kernel filter file.")
681
+ enabled = False
682
+
683
+ # CUTLASS_LIBRARY_IGNORE_KERNELS ("ignore" list) only takes effect
684
+ # if CUTLASS_LIBRARY_KERNELS was specified.
685
+ # Changing that would break backwards compatibility.
686
+ # Thus, CUTLASS has introduced the new CMake option CUTLASS_LIBRARY_EXCLUDE_KERNELS,
687
+ # that always takes effect, whether or not CUTLASS_LIBRARY_KERNELS was specified.
688
+ for name_substr in self.exclude_kernel_names:
689
+ if self._filter_string_matches(name_substr, name):
690
+ _LOGGER.debug(f"Kernel {name} excluded due to filter string '{name_substr}'.")
691
+ enabled = False
692
+ break
693
+ else:
694
+ _LOGGER.debug(f"Kernel {name} NOT excluded due to not matching '{name_substr}'.")
695
+
696
+ # TODO: filter based on compute data type
697
+ return enabled
698
+ #
699
+
700
+ #
701
+ def append(self, operation):
702
+ '''
703
+ Inserts the operation.
704
+
705
+ operation_kind -> configuration_name -> []
706
+ '''
707
+
708
+ if self.filter(operation):
709
+
710
+ self.selected_kernels.append(operation.procedural_name())
711
+
712
+ self.operations_by_name[operation.procedural_name()] = operation
713
+
714
+ # add the configuration
715
+ configuration_name = operation.configuration_name()
716
+
717
+ # Split operations by minimum CC
718
+ min_cc = operation.arch
719
+
720
+ if operation.operation_kind not in self.operations.keys():
721
+ self.operations[operation.operation_kind] = {}
722
+
723
+ if min_cc not in self.operations[operation.operation_kind]:
724
+ self.operations[operation.operation_kind][min_cc] = {}
725
+
726
+ if configuration_name not in self.operations[operation.operation_kind][min_cc].keys():
727
+ self.operations[operation.operation_kind][min_cc][configuration_name] = []
728
+
729
+ self.operations[operation.operation_kind][min_cc][configuration_name].append(operation)
730
+ self.operation_count += 1
731
+ else:
732
+ _LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
733
+ #
734
+
735
+ def emit_manifest_cmake(self, manifest_path, top_level_path, source_files):
736
+ with open(manifest_path, "w") as manifest_file:
737
+
738
+ target_text = SubstituteTemplate("""cutlass_target_sources(cutlass_library_objs PRIVATE
739
+ """, { })
740
+ manifest_file.write(target_text + '\n\n')
741
+ manifest_file.write(" %s\n" % str(top_level_path.replace('\\', '/')))
742
+ generated_path = os.path.join(self.curr_build_dir, 'generated')
743
+ for kind in self.operations.keys():
744
+ kind_str = OperationKindNames[kind]
745
+ all_kind_file = os.path.join(generated_path, kind_str, f"all_{kind_str}_operations.cu").replace('\\', '/')
746
+ manifest_file.write(f" {all_kind_file}\n")
747
+ manifest_file.write(')\n\n')
748
+
749
+ for kind in self.operations.keys():
750
+ for min_cc in sorted(self.operations[kind].keys()):
751
+ for subclass in sorted(source_files[kind][min_cc].keys()):
752
+ target_text = SubstituteTemplate("""cutlass_add_cutlass_library(
753
+ SUFFIX ${kind}_sm${min_cc}_${subclass}
754
+ """, { 'min_cc': str(min_cc), 'kind': OperationKindNames[kind], 'subclass': subclass })
755
+ manifest_file.write(target_text + '\n\n')
756
+
757
+ for source_file in source_files[kind][min_cc][subclass]:
758
+ manifest_file.write(" %s\n" % str(source_file.replace('\\', '/')))
759
+
760
+ manifest_file.write(")\n")
761
+
762
+ if self.disable_full_archs_compilation:
763
+ self.emit_disable_full_archs_compilation(manifest_file, source_files)
764
+
765
+ def emit_disable_full_archs_compilation(manifest_file, source_files):
766
+ def for_hopper(name):
767
+ pass
768
+
769
+ def for_ampere(name):
770
+ return "16816" in name or \
771
+ "16832" in name or \
772
+ "16864" in name or \
773
+ ("1688" in name and "tf32" in name)
774
+
775
+ def for_turing(name):
776
+ return ("1688" in name and "tf32" not in name) or \
777
+ "8816" in name
778
+
779
+ def for_volta(name):
780
+ return "884" in name
781
+
782
+ def is_cpp(name):
783
+ return name.endswith(".cpp")
784
+
785
+ def get_src_archs_str_given_requested_cuda_archs(archs, source_file):
786
+ intersected_archs = archs & set(self.compute_capabilities_baseline)
787
+ if intersected_archs == set():
788
+ raise RuntimeError(
789
+ """
790
+ Empty archs set for file {} after taking
791
+ the intersection of {} (global requested archs) and
792
+ {} (per file requested archs)
793
+ """.format(source_file, set(self.compute_capabilities_baseline), archs))
794
+ else:
795
+ return " ".join(map(str, intersected_archs))
796
+
797
+ for min_cc in sorted(source_files.keys()):
798
+ for source_file in source_files[min_cc]:
799
+ if is_cpp(source_file):
800
+ continue # skip because source is cpp
801
+ elif for_ampere(source_file):
802
+ archs_str = get_src_archs_str_given_requested_cuda_archs({80, 87, 90}, source_file)
803
+ elif for_turing(source_file):
804
+ archs_str = get_src_archs_str_given_requested_cuda_archs({75}, source_file)
805
+ elif for_volta(source_file):
806
+ archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
807
+ else:
808
+ raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
809
+
810
+ manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
811
+
812
+ #
813
+ def emit(self, target = GeneratorTarget.Library):
814
+
815
+ operation_emitters = {
816
+ GeneratorTarget.Library: EmitOperationKindLibrary
817
+ }
818
+
819
+ # Emitters for all operations that fall under a particular kind (e.g., GEMM, Conv2d)
820
+ kind_emitters = {
821
+ GeneratorTarget.Library: EmitOperationKindAll
822
+ }
823
+
824
+ interface_emitters = {
825
+ GeneratorTarget.Library: EmitInterfaceLibrary
826
+ }
827
+
828
+ generated_path = os.path.join(self.curr_build_dir, 'generated')
829
+
830
+ # create generated/
831
+ if os.path.exists(generated_path):
832
+ shutil.rmtree(generated_path)
833
+
834
+ os.mkdir(generated_path)
835
+
836
+ with interface_emitters[target](generated_path, self.operation_count, self.args) as iface_emitter:
837
+ top_level_path = iface_emitter.top_level_path
838
+ for operation_kind in self.operations.keys():
839
+ iface_emitter.emit(OperationKindNames[operation_kind])
840
+
841
+ source_files = {}
842
+ for kind in self.operations.keys():
843
+ source_files[kind] = {}
844
+ for min_cc in self.operations[kind].keys():
845
+ source_files[kind][min_cc] = {}
846
+
847
+ for operation_kind, ops in self.operations.items():
848
+ for min_cc, configurations in sorted(ops.items()):
849
+ with operation_emitters[target](generated_path, min_cc, operation_kind, self.args) as operation_kind_emitter:
850
+ for configuration_name, operations in configurations.items():
851
+ _LOGGER.info(f"Emitting {configuration_name} with {len(operations)} operation{'' if len(operations) == 1 else 's'}.")
852
+ operation_kind_emitter.emit(configuration_name, operations)
853
+
854
+ for subclass, files in operation_kind_emitter.source_files.items():
855
+ if subclass not in source_files[operation_kind][min_cc]:
856
+ source_files[operation_kind][min_cc][subclass] = []
857
+ source_files[operation_kind][min_cc][subclass].extend(operation_kind_emitter.source_files[subclass])
858
+
859
+ # Emit top level all_{gemm, conv2d, ...}_operations.cu files
860
+ with kind_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter:
861
+ operation_kind_emitter.emit(ops)
862
+
863
+ # write the manifest.cmake file containing paths from all targets
864
+ manifest_path = os.path.join(generated_path, "manifest.cmake")
865
+
866
+ self.emit_manifest_cmake(manifest_path, top_level_path, source_files)
867
+
868
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_2k_operation.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting Rank2K kernels
35
+ """
36
+
37
+ import enum
38
+ import functools
39
+ import operator
40
+ import os.path
41
+ import shutil
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ except ImportError:
49
+ from library import *
50
+
51
+
52
+ ###################################################################################################
53
+ #
54
+ # Data structure modeling a Rank K update operation
55
+ #
56
+ ###################################################################################################
57
+
58
+ #
59
+ class Rank2KOperation:
60
+ #
61
+ def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
62
+ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
63
+ blas_mode = BlasMode.symmetric):
64
+
65
+ self.blas_mode = blas_mode
66
+ self.operation_kind = OperationKind.Rank2K
67
+ self.arch = arch
68
+ self.tile_description = tile_description
69
+ self.rank_k_kind = rank_k_kind
70
+ # tensor A and B have same data type and layout
71
+ self.A = A
72
+ self.B = A
73
+ self.C = C
74
+ self.element_epilogue = element_epilogue
75
+ self.epilogue_functor = epilogue_functor
76
+ self.swizzling_functor = swizzling_functor
77
+
78
+ #
79
+ def is_complex(self):
80
+ complex_operators = [
81
+ MathOperation.multiply_add_complex,
82
+ MathOperation.multiply_add_complex_gaussian,
83
+ MathOperation.multiply_add_complex_fast_f32
84
+ ]
85
+ return self.tile_description.math_instruction.math_operation in complex_operators
86
+ return False
87
+
88
+ #
89
+ def is_mixed_input(self):
90
+ return self.A.element != self.B.element
91
+
92
+ #
93
+ def is_planar_complex(self):
94
+ return False
95
+
96
+ #
97
+ def accumulator_type(self):
98
+ accum = self.tile_description.math_instruction.element_accumulator
99
+
100
+ if self.is_complex():
101
+ return get_complex_from_real(accum)
102
+
103
+ return accum
104
+
105
+ #
106
+ def short_math_name(self):
107
+ if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
108
+ return "g%s" % ShortDataTypeNames[self.accumulator_type()]
109
+ return ShortDataTypeNames[self.accumulator_type()]
110
+
111
+
112
+ #
113
+ def core_name(self):
114
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
115
+
116
+ inst_shape = ''
117
+ inst_operation = ''
118
+ intermediate_type = ''
119
+
120
+ math_operations_map = {
121
+ MathOperation.xor_popc: 'xor',
122
+ MathOperation.and_popc: 'and'
123
+ }
124
+
125
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
126
+ self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
127
+
128
+ math_op = self.tile_description.math_instruction.math_operation
129
+ math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
130
+
131
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
132
+ inst_shape += math_op_string
133
+
134
+ if self.tile_description.math_instruction.element_a != self.A.element and \
135
+ self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
136
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
137
+
138
+ operation_name = 'syr2k' if self.blas_mode == BlasMode.symmetric else 'her2k'
139
+
140
+ return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
141
+
142
+ #
143
+ def extended_name(self):
144
+ ''' Append data types if they differ from compute type. '''
145
+ if self.is_complex():
146
+ extended_name = "${core_name}"
147
+ else:
148
+ if self.C.element != self.tile_description.math_instruction.element_accumulator and \
149
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
150
+ extended_name = "${element_c}_${core_name}_${element_a}"
151
+ elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
152
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
153
+ extended_name = "${core_name}_${element_a}"
154
+ else:
155
+ extended_name = "${core_name}"
156
+
157
+ extended_name = SubstituteTemplate(extended_name, {
158
+ 'element_a': DataTypeNames[self.A.element],
159
+ 'element_c': DataTypeNames[self.C.element],
160
+ 'core_name': self.core_name()
161
+ })
162
+
163
+ return extended_name
164
+
165
+ #
166
+ def layout_name(self):
167
+ if self.is_complex() or self.is_planar_complex():
168
+ return "%s" % (
169
+ ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
170
+ )
171
+ return "%s" % (ShortLayoutTypeNames[self.A.layout])
172
+
173
+ #
174
+ def fill_mode_name(self):
175
+ return "%s" % (ShortFillModeNames[self.C.fill_mode])
176
+
177
+ #
178
+ def procedural_name(self):
179
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
180
+ threadblock = self.tile_description.procedural_name()
181
+
182
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
183
+
184
+ alignment = max([self.A.alignment, self.C.alignment])
185
+
186
+ return SubstituteTemplate(
187
+ "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
188
+ {
189
+ 'opcode_class': opcode_class_name,
190
+ 'extended_name': self.extended_name(),
191
+ 'threadblock': threadblock,
192
+ 'layout': self.layout_name(),
193
+ 'fill_mode': self.fill_mode_name(),
194
+ 'alignment': "%d" % self.A.alignment,
195
+ }
196
+ )
197
+
198
+ #
199
+ def configuration_name(self):
200
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
201
+ return self.procedural_name()
202
+
203
+ ###################################################################################################
204
+ #
205
+ # Emits single instances of a CUTLASS device-wide operator
206
+ #
207
+ ###################################################################################################
208
+
209
+ #
210
+ class EmitRank2KUniversalInstance:
211
+ ''' Responsible for emitting a CUTLASS template definition'''
212
+
213
+ def __init__(self):
214
+ self.rank_k_template = """
215
+ // Rank K operator ${operation_name}
216
+ using Operation_${operation_name} =
217
+ typename cutlass::gemm::device::Rank2K<
218
+ ${element_a}, ${layout_a},
219
+ ${element_b}, ${layout_b},
220
+ ${element_c}, ${layout_c}, ${fill_mode},
221
+ ${element_accumulator},
222
+ ${opcode_class},
223
+ ${arch},
224
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
225
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
226
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
227
+ ${epilogue_functor}<
228
+ ${element_c},
229
+ ${epilogue_vector_length},
230
+ ${element_accumulator},
231
+ ${element_epilogue}
232
+ >,
233
+ ${swizzling_functor},
234
+ ${stages},
235
+ ${align_a},
236
+ ${align_b},
237
+ ${split_k_serial},
238
+ ${math_operation}
239
+ >;
240
+ """
241
+ self.rank_k_complex_template = """
242
+ // Rank K operator ${operation_name}
243
+ using Operation_${operation_name} =
244
+ typename cutlass::gemm::device::Rank2K<
245
+ ${element_a}, ${layout_a},
246
+ ${element_b}, ${layout_b},
247
+ ${element_c}, ${layout_c}, ${fill_mode},
248
+ ${element_accumulator},
249
+ ${opcode_class},
250
+ ${arch},
251
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
252
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
253
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
254
+ ${epilogue_functor}<
255
+ ${element_c},
256
+ ${epilogue_vector_length},
257
+ ${element_accumulator},
258
+ ${element_epilogue}
259
+ >,
260
+ ${swizzling_functor},
261
+ ${stages},
262
+ ${align_a},
263
+ ${align_b},
264
+ ${split_k_serial},
265
+ ${math_operation},
266
+ ${transform_a},
267
+ ${transform_b},
268
+ ${blas_mode}
269
+ >;
270
+ """
271
+
272
+ def emit(self, operation):
273
+
274
+ threadblock_shape = operation.tile_description.threadblock_shape
275
+
276
+ warp_count = operation.tile_description.warp_count
277
+ warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
278
+
279
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
280
+
281
+ values = {
282
+ 'operation_name': operation.procedural_name(),
283
+ 'element_a': DataTypeTag[operation.A.element],
284
+ 'layout_a': LayoutTag[operation.A.layout],
285
+ 'element_b': DataTypeTag[operation.B.element],
286
+ 'layout_b': LayoutTag[operation.B.layout],
287
+ 'element_c': DataTypeTag[operation.C.element],
288
+ 'layout_c': LayoutTag[operation.C.layout],
289
+ 'fill_mode': FillModeTag[operation.C.fill_mode],
290
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
291
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
292
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
293
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
294
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
295
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
296
+ 'warp_shape_m': str(warp_shape[0]),
297
+ 'warp_shape_n': str(warp_shape[1]),
298
+ 'warp_shape_k': str(warp_shape[2]),
299
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
300
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
301
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
302
+ 'epilogue_vector_length': str(epilogue_vector_length),
303
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
304
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
305
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
306
+ 'stages': str(operation.tile_description.stages),
307
+ 'align_a': str(operation.A.alignment),
308
+ 'align_b': str(operation.B.alignment),
309
+ 'split_k_serial': 'false',
310
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
311
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
312
+ 'transform_b': ComplexTransformTag[operation.B.complex_transform],
313
+ 'blas_mode': BlasModeTag[operation.blas_mode]
314
+ }
315
+
316
+ rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template
317
+
318
+ return SubstituteTemplate(rank_k_template, values)
319
+
320
+ ###################################################################################################
321
+
322
+
323
+ ###################################################################################################
324
+ #
325
+ # Emitters functions for all targets
326
+ #
327
+ ###################################################################################################
328
+
329
+ class EmitRank2KConfigurationLibrary:
330
+ def __init__(self, operation_path, configuration_name):
331
+ self.configuration_name = configuration_name
332
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
333
+
334
+ self.instance_emitter = {
335
+ RankKKind.Universal: EmitRank2KUniversalInstance,
336
+ }
337
+
338
+ self.rank_k_kind_wrappers = {
339
+ RankKKind.Universal: 'Rank2KOperation',
340
+ }
341
+
342
+ self.instance_template = {
343
+ RankKKind.Universal: """
344
+ ${compile_guard_start}
345
+ manifest.append(new ${rank_k_kind}<
346
+ Operation_${operation_name}
347
+ >("${operation_name}"));
348
+ ${compile_guard_end}
349
+ """
350
+ }
351
+
352
+ self.header_template = """
353
+ /*
354
+ Generated by rank_2k_operation.py - Do not edit.
355
+ */
356
+
357
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
358
+ #include "cutlass/cutlass.h"
359
+ #include "cutlass/library/library.h"
360
+ #include "cutlass/library/manifest.h"
361
+
362
+ #include "library_internal.h"
363
+ #include "rank_2k_operation.h"
364
+
365
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
366
+
367
+ """
368
+
369
+ self.initialize_function_template = """
370
+
371
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
372
+
373
+ namespace cutlass {
374
+ namespace library {
375
+
376
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
377
+
378
+ void initialize_${configuration_name}(Manifest &manifest) {
379
+
380
+ """
381
+ self.epilogue_template = """
382
+
383
+ }
384
+
385
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
386
+
387
+ } // namespace library
388
+ } // namespace cutlass
389
+
390
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
391
+
392
+ """
393
+
394
+ def __enter__(self):
395
+ self.configuration_file = open(self.configuration_path, "w")
396
+ self.configuration_file.write(self.header_template)
397
+
398
+ self.instance_definitions = []
399
+ self.instance_wrappers = []
400
+
401
+ self.operations = []
402
+ return self
403
+
404
+ def emit(self, operation):
405
+ emitter = self.instance_emitter[operation.rank_k_kind]()
406
+
407
+ self.operations.append(operation)
408
+
409
+ self.instance_definitions.append(emitter.emit(operation))
410
+
411
+ self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
412
+ 'configuration_name': self.configuration_name,
413
+ 'operation_name': operation.procedural_name(),
414
+ 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
415
+ 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
416
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
417
+ 'compile_guard_end': "#endif" \
418
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
419
+ }))
420
+
421
+ def __exit__(self, exception_type, exception_value, traceback):
422
+
423
+ # Write instance definitions in top-level namespace
424
+ for instance_definition in self.instance_definitions:
425
+ self.configuration_file.write(instance_definition)
426
+
427
+ # Add wrapper objects within initialize() function
428
+ self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
429
+ 'configuration_name': self.configuration_name
430
+ }))
431
+
432
+ for instance_wrapper in self.instance_wrappers:
433
+ self.configuration_file.write(instance_wrapper)
434
+
435
+ self.configuration_file.write(self.epilogue_template)
436
+ self.configuration_file.close()
437
+
438
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/rank_k_operation.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting RankK kernels
35
+ """
36
+
37
+ import enum
38
+ import functools
39
+ import operator
40
+ import os.path
41
+ import shutil
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ except ImportError:
49
+ from library import *
50
+
51
+
52
+ ###################################################################################################
53
+ #
54
+ # Data structure modeling a Rank K update operation
55
+ #
56
+ ###################################################################################################
57
+
58
+ #
59
+ class RankKOperation:
60
+ #
61
+ def __init__(self, rank_k_kind, arch, tile_description, A, C, element_epilogue, \
62
+ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
63
+ blas_mode = BlasMode.symmetric):
64
+
65
+ self.blas_mode = blas_mode
66
+ self.operation_kind = OperationKind.RankK
67
+ self.arch = arch
68
+ self.tile_description = tile_description
69
+ self.rank_k_kind = rank_k_kind
70
+ self.A = A
71
+ self.C = C
72
+ self.element_epilogue = element_epilogue
73
+ self.epilogue_functor = epilogue_functor
74
+ self.swizzling_functor = swizzling_functor
75
+
76
+ #
77
+ def is_complex(self):
78
+ complex_operators = [
79
+ MathOperation.multiply_add_complex,
80
+ MathOperation.multiply_add_complex_gaussian,
81
+ MathOperation.multiply_add_complex_fast_f32
82
+ ]
83
+ return self.tile_description.math_instruction.math_operation in complex_operators
84
+ return False
85
+
86
+ #
87
+ def is_mixed_input(self):
88
+ return False
89
+
90
+ #
91
+ def is_planar_complex(self):
92
+ return False
93
+
94
+ #
95
+ def accumulator_type(self):
96
+ accum = self.tile_description.math_instruction.element_accumulator
97
+
98
+ if self.is_complex():
99
+ return get_complex_from_real(accum)
100
+
101
+ return accum
102
+
103
+ #
104
+ def short_math_name(self):
105
+ if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
106
+ return "g%s" % ShortDataTypeNames[self.accumulator_type()]
107
+ return ShortDataTypeNames[self.accumulator_type()]
108
+
109
+
110
+ #
111
+ def core_name(self):
112
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
113
+
114
+ inst_shape = ''
115
+ inst_operation = ''
116
+ intermediate_type = ''
117
+
118
+ math_operations_map = {
119
+ MathOperation.xor_popc: 'xor',
120
+ MathOperation.and_popc: 'and'
121
+ }
122
+
123
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
124
+ self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
125
+
126
+ math_op = self.tile_description.math_instruction.math_operation
127
+ math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
128
+
129
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
130
+ inst_shape += math_op_string
131
+
132
+ if self.tile_description.math_instruction.element_a != self.A.element and \
133
+ self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
134
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
135
+
136
+ operation_name = 'syrk' if self.blas_mode == BlasMode.symmetric else 'herk'
137
+
138
+ return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
139
+
140
+ #
141
+ def extended_name(self):
142
+ ''' Append data types if they differ from compute type. '''
143
+ if self.is_complex():
144
+ extended_name = "${core_name}"
145
+ else:
146
+ if self.C.element != self.tile_description.math_instruction.element_accumulator and \
147
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
148
+ extended_name = "${element_c}_${core_name}_${element_a}"
149
+ elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
150
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
151
+ extended_name = "${core_name}_${element_a}"
152
+ else:
153
+ extended_name = "${core_name}"
154
+
155
+ extended_name = SubstituteTemplate(extended_name, {
156
+ 'element_a': DataTypeNames[self.A.element],
157
+ 'element_c': DataTypeNames[self.C.element],
158
+ 'core_name': self.core_name()
159
+ })
160
+
161
+ return extended_name
162
+
163
+ #
164
+ def layout_name(self):
165
+ if self.is_complex() or self.is_planar_complex():
166
+ return "%s" % (
167
+ ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
168
+ )
169
+ return "%s" % (ShortLayoutTypeNames[self.A.layout])
170
+
171
+ #
172
+ def fill_mode_name(self):
173
+ return "%s" % (ShortFillModeNames[self.C.fill_mode])
174
+
175
+ #
176
+ def procedural_name(self):
177
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
178
+ threadblock = self.tile_description.procedural_name()
179
+
180
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
181
+
182
+ alignment = max([self.A.alignment, self.C.alignment])
183
+
184
+ return SubstituteTemplate(
185
+ "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${fill_mode}_align${alignment}",
186
+ {
187
+ 'opcode_class': opcode_class_name,
188
+ 'extended_name': self.extended_name(),
189
+ 'threadblock': threadblock,
190
+ 'layout': self.layout_name(),
191
+ 'fill_mode': self.fill_mode_name(),
192
+ 'alignment': "%d" % self.A.alignment,
193
+ }
194
+ )
195
+
196
+ #
197
+ def configuration_name(self):
198
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
199
+ return self.procedural_name()
200
+
201
+ ###################################################################################################
202
+ #
203
+ # Emits single instances of a CUTLASS device-wide operator
204
+ #
205
+ ###################################################################################################
206
+
207
+ #
208
+ class EmitRankKUniversalInstance:
209
+ ''' Responsible for emitting a CUTLASS template definition'''
210
+
211
+ def __init__(self):
212
+ self.rank_k_template = """
213
+ // Rank K operator ${operation_name}
214
+ using Operation_${operation_name} =
215
+ typename cutlass::gemm::device::RankK<
216
+ ${element_a}, ${layout_a},
217
+ ${element_c}, ${layout_c}, ${fill_mode},
218
+ ${element_accumulator},
219
+ ${opcode_class},
220
+ ${arch},
221
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
222
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
223
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
224
+ ${epilogue_functor}<
225
+ ${element_c},
226
+ ${epilogue_vector_length},
227
+ ${element_accumulator},
228
+ ${element_epilogue}
229
+ >,
230
+ ${swizzling_functor},
231
+ ${stages},
232
+ ${align_a},
233
+ ${split_k_serial},
234
+ ${math_operation}
235
+ >;
236
+ """
237
+ self.rank_k_complex_template = """
238
+ // Rank K operator ${operation_name}
239
+ using Operation_${operation_name} =
240
+ typename cutlass::gemm::device::RankK<
241
+ ${element_a}, ${layout_a},
242
+ ${element_c}, ${layout_c}, ${fill_mode},
243
+ ${element_accumulator},
244
+ ${opcode_class},
245
+ ${arch},
246
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
247
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
248
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
249
+ ${epilogue_functor}<
250
+ ${element_c},
251
+ ${epilogue_vector_length},
252
+ ${element_accumulator},
253
+ ${element_epilogue}
254
+ >,
255
+ ${swizzling_functor},
256
+ ${stages},
257
+ ${align_a},
258
+ ${split_k_serial},
259
+ ${math_operation},
260
+ ${transform_a},
261
+ ${blas_mode}
262
+ >;
263
+ """
264
+
265
+ def emit(self, operation):
266
+
267
+ threadblock_shape = operation.tile_description.threadblock_shape
268
+
269
+ warp_count = operation.tile_description.warp_count
270
+ warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
271
+
272
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
273
+
274
+ values = {
275
+ 'operation_name': operation.procedural_name(),
276
+ 'element_a': DataTypeTag[operation.A.element],
277
+ 'layout_a': LayoutTag[operation.A.layout],
278
+ 'element_c': DataTypeTag[operation.C.element],
279
+ 'layout_c': LayoutTag[operation.C.layout],
280
+ 'fill_mode': FillModeTag[operation.C.fill_mode],
281
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
282
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
283
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
284
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
285
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
286
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
287
+ 'warp_shape_m': str(warp_shape[0]),
288
+ 'warp_shape_n': str(warp_shape[1]),
289
+ 'warp_shape_k': str(warp_shape[2]),
290
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
291
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
292
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
293
+ 'epilogue_vector_length': str(epilogue_vector_length),
294
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
295
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
296
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
297
+ 'stages': str(operation.tile_description.stages),
298
+ 'align_a': str(operation.A.alignment),
299
+ 'split_k_serial': 'false',
300
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
301
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform],
302
+ 'blas_mode': BlasModeTag[operation.blas_mode]
303
+ }
304
+
305
+ rank_k_template = self.rank_k_complex_template if operation.is_complex() else self.rank_k_template
306
+
307
+ return SubstituteTemplate(rank_k_template, values)
308
+
309
+ ###################################################################################################
310
+
311
+
312
+ ###################################################################################################
313
+ #
314
+ # Emitters functions for all targets
315
+ #
316
+ ###################################################################################################
317
+
318
+ class EmitRankKConfigurationLibrary:
319
+ def __init__(self, operation_path, configuration_name):
320
+ self.configuration_name = configuration_name
321
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
322
+
323
+ self.instance_emitter = {
324
+ RankKKind.Universal: EmitRankKUniversalInstance,
325
+ }
326
+
327
+ self.rank_k_kind_wrappers = {
328
+ RankKKind.Universal: 'RankKOperation',
329
+ }
330
+
331
+ self.instance_template = {
332
+ RankKKind.Universal: """
333
+ ${compile_guard_start}
334
+ manifest.append(new ${rank_k_kind}<
335
+ Operation_${operation_name}
336
+ >("${operation_name}"));
337
+ ${compile_guard_end}
338
+ """
339
+ }
340
+
341
+ self.header_template = """
342
+ /*
343
+ Generated by rank_k_operation.py - Do not edit.
344
+ */
345
+
346
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
347
+ #include "cutlass/cutlass.h"
348
+ #include "cutlass/library/library.h"
349
+ #include "cutlass/library/manifest.h"
350
+
351
+ #include "library_internal.h"
352
+ #include "rank_k_operation.h"
353
+
354
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
355
+
356
+ """
357
+
358
+ self.initialize_function_template = """
359
+
360
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
361
+
362
+ namespace cutlass {
363
+ namespace library {
364
+
365
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
366
+
367
+ void initialize_${configuration_name}(Manifest &manifest) {
368
+
369
+ """
370
+ self.epilogue_template = """
371
+
372
+ }
373
+
374
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
375
+
376
+ } // namespace library
377
+ } // namespace cutlass
378
+
379
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
380
+
381
+ """
382
+
383
+ def __enter__(self):
384
+ self.configuration_file = open(self.configuration_path, "w")
385
+ self.configuration_file.write(self.header_template)
386
+
387
+ self.instance_definitions = []
388
+ self.instance_wrappers = []
389
+
390
+ self.operations = []
391
+ return self
392
+
393
+ def emit(self, operation):
394
+ emitter = self.instance_emitter[operation.rank_k_kind]()
395
+
396
+ self.operations.append(operation)
397
+
398
+ self.instance_definitions.append(emitter.emit(operation))
399
+
400
+ self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.rank_k_kind], {
401
+ 'configuration_name': self.configuration_name,
402
+ 'operation_name': operation.procedural_name(),
403
+ 'rank_k_kind': self.rank_k_kind_wrappers[operation.rank_k_kind],
404
+ 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
405
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
406
+ 'compile_guard_end': "#endif" \
407
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
408
+ }))
409
+
410
+ def __exit__(self, exception_type, exception_value, traceback):
411
+
412
+ # Write instance definitions in top-level namespace
413
+ for instance_definition in self.instance_definitions:
414
+ self.configuration_file.write(instance_definition)
415
+
416
+ # Add wrapper objects within initialize() function
417
+ self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
418
+ 'configuration_name': self.configuration_name
419
+ }))
420
+
421
+ for instance_wrapper in self.instance_wrappers:
422
+ self.configuration_file.write(instance_wrapper)
423
+
424
+ self.configuration_file.write(self.epilogue_template)
425
+ self.configuration_file.close()
426
+
427
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_shapes.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Valid tcgen05 shapes and cluster sizes for SM100, associated with levels.
35
+ These shape and level pairs are defined as dicts, where keys are shapes and values are their
36
+ associated levels. If the user input level for that category (tcgen05 shape, cluster
37
+ size) is smaller than a shape's associated level, it will be excluded, and otherwise, included.
38
+ Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently.
39
+ Level 0 is always emitted.
40
+ """
41
+
42
+ try:
43
+ from .library import DynamicClusterShape
44
+ except:
45
+ from library import DynamicClusterShape
46
+
47
+ SM100_CLUSTER_SHAPES_1SM = {
48
+ tuple(DynamicClusterShape) : 0,
49
+ # size 1 cluster
50
+ (1, 1, 1): 1,
51
+ # size 2 cluster
52
+ (1, 2, 1): 2,
53
+ (2, 1, 1): 5,
54
+ # size 4 clusters
55
+ (2, 2, 1): 6,
56
+ (1, 4, 1): 3,
57
+ (4, 1, 1): 6,
58
+ # size 8 clusters
59
+ (2, 4, 1): 7,
60
+ (4, 2, 1): 7,
61
+ (1, 8, 1): 8,
62
+ (8, 1, 1): 8,
63
+ # size 16 cluster
64
+ (4, 4, 1): 4,
65
+ }
66
+
67
+ SM100_CLUSTER_SHAPES_2SM = {
68
+ tuple(DynamicClusterShape) : 0,
69
+ # size 2 cluster
70
+ (2, 1, 1): 1,
71
+ # size 4 clusters
72
+ (2, 2, 1): 2,
73
+ (4, 1, 1): 2,
74
+ # size 8 clusters
75
+ (2, 4, 1): 3,
76
+ (4, 2, 1): 3,
77
+ (8, 1, 1): 6,
78
+ # size 16 cluster
79
+ (4, 4, 1): 4,
80
+ }
81
+
82
+ # MMA shapes
83
+
84
+ # 16b Dense
85
+
86
+ SM100_MMA_SHAPES_16b_DENSE_1SM = {
87
+ (64, 8, 16): 5,
88
+ (64, 16, 16): 2,
89
+ (64, 24, 16): 5,
90
+ (64, 32, 16): 2,
91
+ (64, 40, 16): 5,
92
+ (64, 48, 16): 5,
93
+ (64, 56, 16): 5,
94
+ (64, 64, 16): 2,
95
+ (64, 72, 16): 5,
96
+ (64, 80, 16): 5,
97
+ (64, 88, 16): 5,
98
+ (64, 96, 16): 5,
99
+ (64, 104, 16): 5,
100
+ (64, 112, 16): 5,
101
+ (64, 120, 16): 5,
102
+ (64, 128, 16): 0,
103
+ (64, 136, 16): 5,
104
+ (64, 144, 16): 5,
105
+ (64, 152, 16): 5,
106
+ (64, 160, 16): 5,
107
+ (64, 168, 16): 5,
108
+ (64, 176, 16): 5,
109
+ (64, 184, 16): 5,
110
+ (64, 192, 16): 3,
111
+ (64, 200, 16): 5,
112
+ (64, 208, 16): 5,
113
+ (64, 216, 16): 5,
114
+ (64, 224, 16): 5,
115
+ (64, 232, 16): 5,
116
+ (64, 240, 16): 5,
117
+ (64, 248, 16): 5,
118
+ (64, 256, 16): 3,
119
+
120
+ (128, 16, 16): 2,
121
+ (128, 32, 16): 2,
122
+ (128, 48, 16): 5,
123
+ (128, 64, 16): 2,
124
+ (128, 80, 16): 5,
125
+ (128, 96, 16): 5,
126
+ (128, 112, 16): 5,
127
+ (128, 128, 16): 0,
128
+ (128, 144, 16): 5,
129
+ (128, 160, 16): 5,
130
+ (128, 176, 16): 5,
131
+ (128, 192, 16): 3,
132
+ (128, 208, 16): 5,
133
+ (128, 224, 16): 5,
134
+ (128, 240, 16): 5,
135
+ (128, 256, 16): 0,
136
+
137
+ }
138
+
139
+
140
+ SM100_MMA_SHAPES_16b_DENSE_2SM = {
141
+ (128, 32, 16): 2,
142
+ (128, 64, 16): 2,
143
+ (128, 96, 16): 5,
144
+ (128, 128, 16): 0,
145
+ (128, 160, 16): 5,
146
+ (128, 192, 16): 5,
147
+ (128, 224, 16): 5,
148
+ (128, 256, 16): 0,
149
+
150
+ (256, 32, 16): 2,
151
+ (256, 64, 16): 2,
152
+ (256, 96, 16): 5,
153
+ (256, 128, 16): 0,
154
+ (256, 160, 16): 5,
155
+ (256, 192, 16): 3,
156
+ (256, 224, 16): 5,
157
+ (256, 256, 16): 0,
158
+ }
159
+
160
+ # TF32 Dense
161
+
162
+ SM100_MMA_SHAPES_TF32_DENSE_1SM = {
163
+ (64, 8, 8): 5,
164
+ (64, 16, 8): 2,
165
+ (64, 24, 8): 5,
166
+ (64, 32, 8): 2,
167
+ (64, 40, 8): 5,
168
+ (64, 48, 8): 5,
169
+ (64, 56, 8): 5,
170
+ (64, 64, 8): 1,
171
+ (64, 72, 8): 5,
172
+ (64, 80, 8): 5,
173
+ (64, 88, 8): 5,
174
+ (64, 96, 8): 5,
175
+ (64, 104, 8): 5,
176
+ (64, 112, 8): 5,
177
+ (64, 120, 8): 5,
178
+ (64, 128, 8): 0,
179
+ (64, 136, 8): 5,
180
+ (64, 144, 8): 5,
181
+ (64, 152, 8): 5,
182
+ (64, 160, 8): 5,
183
+ (64, 168, 8): 5,
184
+ (64, 176, 8): 5,
185
+ (64, 184, 8): 5,
186
+ (64, 192, 8): 3,
187
+ (64, 200, 8): 5,
188
+ (64, 208, 8): 5,
189
+ (64, 216, 8): 5,
190
+ (64, 224, 8): 5,
191
+ (64, 232, 8): 5,
192
+ (64, 240, 8): 5,
193
+ (64, 248, 8): 5,
194
+ (64, 256, 8): 3,
195
+
196
+ (128, 16, 8): 2,
197
+ (128, 32, 8): 2,
198
+ (128, 48, 8): 5,
199
+ (128, 64, 8): 2,
200
+ (128, 80, 8): 5,
201
+ (128, 96, 8): 5,
202
+ (128, 112, 8): 5,
203
+ (128, 128, 8): 0,
204
+ (128, 144, 8): 5,
205
+ (128, 160, 8): 5,
206
+ (128, 176, 8): 5,
207
+ (128, 192, 8): 3,
208
+ (128, 208, 8): 5,
209
+ (128, 224, 8): 5,
210
+ (128, 240, 8): 5,
211
+ (128, 256, 8): 0,
212
+
213
+ }
214
+
215
+ SM100_MMA_SHAPES_TF32_DENSE_2SM = {
216
+ (128, 32, 8): 2,
217
+ (128, 64, 8): 1,
218
+ (128, 96, 8): 5,
219
+ (128, 128, 8): 0,
220
+ (128, 160, 8): 5,
221
+ (128, 192, 8): 5,
222
+ (128, 224, 8): 5,
223
+ (128, 256, 8): 0,
224
+
225
+ (256, 32, 8): 2,
226
+ (256, 64, 8): 1,
227
+ (256, 96, 8): 5,
228
+ (256, 128, 8): 0,
229
+ (256, 160, 8): 5,
230
+ (256, 192, 8): 5,
231
+ (256, 224, 8): 5,
232
+ (256, 256, 8): 0,
233
+ }
234
+
235
+ # F8F6F4
236
+ SM100_MMA_SHAPES_F8F6F4_DENSE_1SM = {
237
+ (64, 8, 32): 4,
238
+ (64, 16, 32): 4,
239
+ (64, 24, 32): 5,
240
+ (64, 32, 32): 3,
241
+ (64, 40, 32): 5,
242
+ (64, 48, 32): 5,
243
+ (64, 56, 32): 5,
244
+ (64, 64, 32): 2,
245
+ (64, 72, 32): 5,
246
+ (64, 80, 32): 5,
247
+ (64, 88, 32): 5,
248
+ (64, 96, 32): 5,
249
+ (64, 104, 32): 5,
250
+ (64, 112, 32): 5,
251
+ (64, 120, 32): 5,
252
+ (64, 128, 32): 0,
253
+ (64, 136, 32): 5,
254
+ (64, 144, 32): 5,
255
+ (64, 152, 32): 5,
256
+ (64, 160, 32): 5,
257
+ (64, 168, 32): 5,
258
+ (64, 176, 32): 5,
259
+ (64, 184, 32): 5,
260
+ (64, 192, 32): 5,
261
+ (64, 200, 32): 5,
262
+ (64, 208, 32): 5,
263
+ (64, 216, 32): 5,
264
+ (64, 224, 32): 5,
265
+ (64, 232, 32): 5,
266
+ (64, 240, 32): 5,
267
+ (64, 248, 32): 5,
268
+ (64, 256, 32): 0,
269
+
270
+ (128, 16, 32): 4,
271
+ (128, 32, 32): 3,
272
+ (128, 48, 32): 5,
273
+ (128, 64, 32): 2,
274
+ (128, 80, 32): 5,
275
+ (128, 96, 32): 5,
276
+ (128, 112, 32): 5,
277
+ (128, 128, 32): 0,
278
+ (128, 144, 32): 5,
279
+ (128, 160, 32): 5,
280
+ (128, 176, 32): 5,
281
+ (128, 192, 32): 5,
282
+ (128, 208, 32): 5,
283
+ (128, 224, 32): 5,
284
+ (128, 240, 32): 5,
285
+ (128, 256, 32): 0,
286
+
287
+ }
288
+
289
+ SM100_MMA_SHAPES_F8F6F4_DENSE_2SM = {
290
+ (128, 32, 32): 3,
291
+ (128, 64, 32): 2,
292
+ (128, 96, 32): 5,
293
+ (128, 128, 32): 1,
294
+ (128, 160, 32): 5,
295
+ (128, 192, 32): 5,
296
+ (128, 224, 32): 5,
297
+ (128, 256, 32): 1,
298
+
299
+ (256, 32, 32): 2,
300
+ (256, 64, 32): 2,
301
+ (256, 96, 32): 5,
302
+ (256, 128, 32): 0,
303
+ (256, 160, 32): 5,
304
+ (256, 192, 32): 5,
305
+ (256, 224, 32): 5,
306
+ (256, 256, 32): 0,
307
+ }
308
+
309
+ # MXF8F6F4
310
+ SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM = {
311
+ (128, 64, 32): 1,
312
+ (128, 128, 32): 0,
313
+ (128, 192, 32): 1,
314
+ (128, 256, 32): 0,
315
+ }
316
+
317
+
318
+ SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM = {
319
+ (256, 64, 32): 1,
320
+ (256, 128, 32): 0,
321
+ (256, 192, 32): 1,
322
+ (256, 256, 32): 0,
323
+
324
+
325
+ }
326
+
327
+ # MXF4NVF4
328
+ SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM = {
329
+ (128, 64, 64): 1,
330
+ (128, 128, 64): 0,
331
+ (128, 192, 64): 1,
332
+ (128, 256, 64): 0,
333
+ }
334
+
335
+ SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM = {
336
+ # Multiples of 16 for N
337
+ (256, 64, 64): 1,
338
+ (256, 128, 64): 0,
339
+ (256, 192, 64): 1,
340
+ (256, 256, 64): 0,
341
+
342
+ }
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm100_utils.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for enumerating CUTLASS library SM100 kernels
35
+ """
36
+
37
+ import argparse
38
+ import enum
39
+ from itertools import product
40
+ import math
41
+ import logging
42
+ import os.path
43
+ import shutil
44
+ import sys
45
+ import copy
46
+ from typing import Any, Optional, Sequence, Tuple, List, Union, Callable
47
+
48
+ try:
49
+ import builtins
50
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
51
+ raise ImportError("Disabling attempt to import cutlass_library")
52
+ from cutlass_library.library import *
53
+ except ImportError:
54
+ from library import *
55
+
56
+ #### Step 0: define levels
57
+
58
+ # One integer level controls multiple "generators" and how many
59
+ # combinations they generate. That is the "global" level.
60
+ # "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and
61
+ # anything that is eventually involved in the Cartesian product
62
+ # which yields our kernel configurations.
63
+ # For simplicity, each generator defines their own levels,
64
+ # starting from 0. As a rule we assume 10 or fewer levels, making
65
+ # their level a digit.
66
+ # The "global" level simply stacks these digits and represents them
67
+ # as a single integer.
68
+ #
69
+ # For example, level 500 indicates cluster sizes are at level 5, MMA
70
+ # multipliers are at level 0, and WGMMA shapes are at level 0 as well.
71
+ #
72
+ # Here we define the global level to generator level mappings.
73
+
74
+
75
+ def get_tcgen05_level_from_global_level(global_level: int):
76
+ return global_level % 10
77
+
78
+ def get_mma_level_from_global_level(global_level: int):
79
+ return (global_level // 10) % 10
80
+
81
+
82
+ def get_cluster_level_from_global_level(global_level: int):
83
+ return (global_level // 100) % 10
84
+
85
+
86
+ def get_pruning_level_from_global_level(global_level: int):
87
+ return (global_level // 1000) % 10
88
+
89
+
90
+ #### Step 1: generate MMA instruction shapes based on levels
91
+
92
+ try:
93
+ from .sm100_shapes import *
94
+ except:
95
+ from sm100_shapes import *
96
+
97
+ ###########
98
+
99
+ def generate_tf32_math_instructions_sm100(level: int):
100
+ """
101
+ Generate all TensorOp math instructions for TF32 MMA that are supported by SM100 at or above the given level.
102
+
103
+ Args:
104
+ level: The global level to generate math instructions for.
105
+
106
+ Returns:
107
+ A tuple of two lists of MathInstruction objects.
108
+ The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
109
+ """
110
+ tcgen05_level = get_tcgen05_level_from_global_level(level)
111
+ math_instructions_1sm = []
112
+ math_instructions_2sm = []
113
+
114
+ shapes_1sm = [
115
+ shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_1SM.items() if tcgen05_level >= min_level
116
+ ]
117
+ shapes_2sm = [
118
+ shape for shape, min_level in SM100_MMA_SHAPES_TF32_DENSE_2SM.items() if tcgen05_level >= min_level
119
+ ]
120
+
121
+ for shape in shapes_1sm:
122
+ math_instructions_1sm.append(
123
+ MathInstruction(
124
+ shape,
125
+ DataType.tf32, DataType.tf32, DataType.f32,
126
+ OpcodeClass.TensorOp,
127
+ MathOperation.multiply_add)
128
+ )
129
+
130
+ for shape in shapes_2sm:
131
+ math_instructions_2sm.append(
132
+ MathInstruction(
133
+ shape,
134
+ DataType.tf32, DataType.tf32, DataType.f32,
135
+ OpcodeClass.TensorOp,
136
+ MathOperation.multiply_add)
137
+ )
138
+
139
+ return math_instructions_1sm, math_instructions_2sm
140
+
141
+ def generate_16b_math_instructions_sm100(level: int):
142
+ """
143
+ Generate all TensorOp math instructions for 16b MMA that are supported by SM100 at or above the given level.
144
+
145
+ Args:
146
+ level: The global level to generate math instructions for.
147
+
148
+ Returns:
149
+ A tuple of two lists of MathInstruction objects.
150
+ The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
151
+ """
152
+ tcgen05_level = get_tcgen05_level_from_global_level(level)
153
+ math_instructions_1sm = []
154
+ math_instructions_2sm = []
155
+
156
+ shapes_1sm = [
157
+ shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_1SM.items() if tcgen05_level >= min_level
158
+ ]
159
+ shapes_2sm = [
160
+ shape for shape, min_level in SM100_MMA_SHAPES_16b_DENSE_2SM.items() if tcgen05_level >= min_level
161
+ ]
162
+
163
+ for shape in shapes_1sm:
164
+ math_instructions_1sm.append(
165
+ MathInstruction(
166
+ shape,
167
+ DataType.f16, DataType.f16, DataType.f16,
168
+ OpcodeClass.TensorOp,
169
+ MathOperation.multiply_add)
170
+ )
171
+ math_instructions_1sm.append(
172
+ MathInstruction(
173
+ shape,
174
+ DataType.f16, DataType.f16, DataType.f32,
175
+ OpcodeClass.TensorOp,
176
+ MathOperation.multiply_add)
177
+ )
178
+ math_instructions_1sm.append(
179
+ MathInstruction(
180
+ shape,
181
+ DataType.bf16, DataType.bf16, DataType.f32,
182
+ OpcodeClass.TensorOp,
183
+ MathOperation.multiply_add)
184
+ )
185
+
186
+
187
+ for shape in shapes_2sm:
188
+ math_instructions_2sm.append(
189
+ MathInstruction(
190
+ shape,
191
+ DataType.f16, DataType.f16, DataType.f16,
192
+ OpcodeClass.TensorOp,
193
+ MathOperation.multiply_add)
194
+ )
195
+ math_instructions_2sm.append(
196
+ MathInstruction(
197
+ shape,
198
+ DataType.f16, DataType.f16, DataType.f32,
199
+ OpcodeClass.TensorOp,
200
+ MathOperation.multiply_add)
201
+ )
202
+ math_instructions_2sm.append(
203
+ MathInstruction(
204
+ shape,
205
+ DataType.bf16, DataType.bf16, DataType.f32,
206
+ OpcodeClass.TensorOp,
207
+ MathOperation.multiply_add)
208
+ )
209
+
210
+ return math_instructions_1sm, math_instructions_2sm
211
+
212
+
213
+ def generate_fp8_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
214
+ """
215
+ Generate all TensorOp math instructions for FP8 MMA that are supported by SM100 at or above the given level.
216
+
217
+ Args:
218
+ level: The global level to generate math instructions for.
219
+ enable_runtime_dtype: Whether to generate runtime dtype math instructions.
220
+ enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
221
+
222
+ Returns:
223
+ A tuple of two lists of MathInstruction objects.
224
+ The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
225
+ """
226
+
227
+ tcgen05_level = get_tcgen05_level_from_global_level(level)
228
+ pruning_level = get_pruning_level_from_global_level(level)
229
+ math_instructions_1sm = []
230
+ math_instructions_2sm = []
231
+
232
+ shapes_1sm = [
233
+ shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level
234
+ ]
235
+ shapes_2sm = [
236
+ shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level
237
+ ]
238
+
239
+ for shape in shapes_1sm:
240
+ if enable_runtime_dtype:
241
+ math_instructions_1sm.append(
242
+ MathInstruction(
243
+ shape,
244
+ DataType.f8, DataType.f8, DataType.f32,
245
+ OpcodeClass.TensorOp,
246
+ MathOperation.multiply_add)
247
+ )
248
+ if enable_compile_time_dtype:
249
+ math_instructions_1sm.append(
250
+ MathInstruction(
251
+ shape,
252
+ DataType.e4m3, DataType.e4m3, DataType.f32,
253
+ OpcodeClass.TensorOp,
254
+ MathOperation.multiply_add)
255
+ )
256
+ math_instructions_1sm.append(
257
+ MathInstruction(
258
+ shape,
259
+ DataType.e5m2, DataType.e4m3, DataType.f32,
260
+ OpcodeClass.TensorOp,
261
+ MathOperation.multiply_add)
262
+ )
263
+ math_instructions_1sm.append(
264
+ MathInstruction(
265
+ shape,
266
+ DataType.e4m3, DataType.e5m2, DataType.f32,
267
+ OpcodeClass.TensorOp,
268
+ MathOperation.multiply_add)
269
+ )
270
+ if pruning_level >= 2:
271
+ math_instructions_1sm.append(
272
+ MathInstruction(
273
+ shape,
274
+ DataType.e5m2, DataType.e5m2, DataType.f32,
275
+ OpcodeClass.TensorOp,
276
+ MathOperation.multiply_add)
277
+ )
278
+
279
+ for shape in shapes_2sm:
280
+ if enable_runtime_dtype:
281
+ math_instructions_2sm.append(
282
+ MathInstruction(
283
+ shape,
284
+ DataType.f8, DataType.f8, DataType.f32,
285
+ OpcodeClass.TensorOp,
286
+ MathOperation.multiply_add)
287
+ )
288
+ if enable_compile_time_dtype:
289
+ math_instructions_2sm.append(
290
+ MathInstruction(
291
+ shape,
292
+ DataType.e4m3, DataType.e4m3, DataType.f32,
293
+ OpcodeClass.TensorOp,
294
+ MathOperation.multiply_add)
295
+ )
296
+ math_instructions_2sm.append(
297
+ MathInstruction(
298
+ shape,
299
+ DataType.e5m2, DataType.e4m3, DataType.f32,
300
+ OpcodeClass.TensorOp,
301
+ MathOperation.multiply_add)
302
+ )
303
+ math_instructions_2sm.append(
304
+ MathInstruction(
305
+ shape,
306
+ DataType.e4m3, DataType.e5m2, DataType.f32,
307
+ OpcodeClass.TensorOp,
308
+ MathOperation.multiply_add)
309
+ )
310
+ if pruning_level >= 2:
311
+ math_instructions_2sm.append(
312
+ MathInstruction(
313
+ shape,
314
+ DataType.e5m2, DataType.e5m2, DataType.f32,
315
+ OpcodeClass.TensorOp,
316
+ MathOperation.multiply_add)
317
+ )
318
+
319
+ return math_instructions_1sm, math_instructions_2sm
320
+
321
+ def generate_f8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
322
+ """
323
+ Generate all TensorOp math instructions for FP8 FP6 and FP4 MMA that are supported by SM100 at or above the given level.
324
+
325
+ Args:
326
+ level: The global level to generate math instructions for.
327
+ enable_runtime_dtype: Whether to generate runtime dtype math instructions.
328
+ enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
329
+
330
+ Returns:
331
+ A tuple of two lists of MathInstruction objects.
332
+ The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
333
+ """
334
+
335
+ tcgen05_level = get_tcgen05_level_from_global_level(level)
336
+ math_instructions_1sm = []
337
+ math_instructions_2sm = []
338
+
339
+ shapes_1sm = [
340
+ shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level
341
+ ]
342
+ shapes_2sm = [
343
+ shape for shape, min_level in SM100_MMA_SHAPES_F8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level
344
+ ]
345
+
346
+ for shape in shapes_1sm:
347
+ if enable_runtime_dtype:
348
+
349
+ runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
350
+
351
+ for a_type, b_type in product(runtime_types, repeat=2):
352
+ math_instructions_1sm.append(
353
+ MathInstruction(
354
+ shape,
355
+ a_type, b_type, DataType.f32,
356
+ OpcodeClass.TensorOp,
357
+ MathOperation.multiply_add)
358
+ )
359
+
360
+ if enable_compile_time_dtype:
361
+ compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ]
362
+
363
+ for a_type, b_type in product(compile_time_types, repeat=2):
364
+ math_instructions_1sm.append(
365
+ MathInstruction(
366
+ shape,
367
+ a_type, b_type, DataType.f32,
368
+ OpcodeClass.TensorOp,
369
+ MathOperation.multiply_add)
370
+ )
371
+
372
+
373
+ for shape in shapes_2sm:
374
+ if enable_runtime_dtype:
375
+
376
+ runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
377
+
378
+ for a_type, b_type in product(runtime_types, repeat=2):
379
+ math_instructions_2sm.append(
380
+ MathInstruction(
381
+ shape,
382
+ a_type, b_type, DataType.f32,
383
+ OpcodeClass.TensorOp,
384
+ MathOperation.multiply_add)
385
+ )
386
+
387
+ if enable_compile_time_dtype:
388
+ compile_time_types = [ DataType.e4m3, DataType.e5m2, DataType.e3m2, DataType.e2m1 ]
389
+
390
+ for a_type, b_type in product(compile_time_types, repeat=2):
391
+ math_instructions_2sm.append(
392
+ MathInstruction(
393
+ shape,
394
+ a_type, b_type, DataType.f32,
395
+ OpcodeClass.TensorOp,
396
+ MathOperation.multiply_add)
397
+ )
398
+
399
+ return math_instructions_1sm, math_instructions_2sm
400
+
401
+ def generate_mxf8f6f4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
402
+ """
403
+ Generate all BlockScaledTensorOp math instructions for MXFP8, MXFP6, and MXFP4 MMA that are supported by SM100 at or above the given level.
404
+
405
+ Args:
406
+ level: The global level to generate math instructions for.
407
+ enable_runtime_dtype: Whether to generate runtime dtype math instructions.
408
+ enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
409
+
410
+ Returns:
411
+ A tuple of two lists of MathInstruction objects.
412
+ The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
413
+ """
414
+
415
+ tcgen05_level = get_tcgen05_level_from_global_level(level)
416
+ pruning_level = get_pruning_level_from_global_level(level)
417
+
418
+ math_instructions_1sm = []
419
+ math_instructions_2sm = []
420
+
421
+ shapes_1sm = [
422
+ shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_1SM.items() if tcgen05_level >= min_level
423
+ ]
424
+ shapes_2sm = [
425
+ shape for shape, min_level in SM100_MMA_SHAPES_MXF8F6F4_DENSE_2SM.items() if tcgen05_level >= min_level
426
+ ]
427
+
428
+ for shape in shapes_1sm:
429
+ if enable_runtime_dtype:
430
+
431
+ runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
432
+
433
+ for a_type, b_type in product(runtime_types, repeat=2):
434
+
435
+ if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)):
436
+ continue
437
+
438
+ math_instructions_1sm.append(
439
+ MathInstruction(
440
+ shape,
441
+ a_type, b_type, DataType.f32,
442
+ OpcodeClass.BlockScaledTensorOp,
443
+ MathOperation.multiply_add,
444
+ DataType.ue8m0)
445
+ )
446
+
447
+ if enable_compile_time_dtype:
448
+ compile_time_types = [ DataType.e4m3,
449
+ DataType.e5m2,
450
+ DataType.e3m2,
451
+ DataType.e2m3,
452
+ DataType.e2m1 ]
453
+
454
+ for a_type, b_type in product(compile_time_types, repeat=2):
455
+ math_instructions_1sm.append(
456
+ MathInstruction(
457
+ shape,
458
+ a_type, b_type, DataType.f32,
459
+ OpcodeClass.BlockScaledTensorOp,
460
+ MathOperation.multiply_add,
461
+ DataType.ue8m0)
462
+ )
463
+
464
+
465
+ for shape in shapes_2sm:
466
+ if enable_runtime_dtype:
467
+
468
+ runtime_types = [ DataType.f8, DataType.f6, DataType.f4 ]
469
+
470
+ for a_type, b_type in product(runtime_types, repeat=2):
471
+
472
+ if pruning_level < 2 and ((a_type == DataType.f8 or b_type == DataType.f8)):
473
+ continue
474
+
475
+ math_instructions_2sm.append(
476
+ MathInstruction(
477
+ shape,
478
+ a_type, b_type, DataType.f32,
479
+ OpcodeClass.BlockScaledTensorOp,
480
+ MathOperation.multiply_add,
481
+ DataType.ue8m0)
482
+ )
483
+
484
+ if enable_compile_time_dtype:
485
+ compile_time_types = [ DataType.e4m3,
486
+ DataType.e5m2,
487
+ DataType.e3m2,
488
+ DataType.e2m3,
489
+ DataType.e2m1 ]
490
+
491
+ for a_type, b_type in product(compile_time_types, repeat=2):
492
+ math_instructions_2sm.append(
493
+ MathInstruction(
494
+ shape,
495
+ a_type, b_type, DataType.f32,
496
+ OpcodeClass.BlockScaledTensorOp,
497
+ MathOperation.multiply_add,
498
+ DataType.ue8m0)
499
+ )
500
+
501
+ return math_instructions_1sm, math_instructions_2sm
502
+
503
+ def generate_mxf4nvf4_math_instructions_sm100(level: int, enable_runtime_dtype = True, enable_compile_time_dtype = True):
504
+ """
505
+ Generate all BlockScaledTensorOp math instructions for MXFP4 and MXFP4 MMA that are supported by SM100 at or above the given level.
506
+
507
+ Args:
508
+ level: The global level to generate math instructions for.
509
+ enable_runtime_dtype: Whether to generate runtime dtype math instructions.
510
+ enable_compile_time_dtype: Whether to generate compile time dtype math instructions.
511
+
512
+ Returns:
513
+ A tuple of two lists of MathInstruction objects.
514
+ The first list contains the math instructions for 1SM, and the second list contains the math instructions for 2SM.
515
+ """
516
+ tcgen05_level = get_tcgen05_level_from_global_level(level)
517
+ math_instructions_1sm = []
518
+ math_instructions_2sm = []
519
+
520
+ shapes_1sm = [
521
+ shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_1SM.items() if tcgen05_level >= min_level
522
+ ]
523
+ shapes_2sm = [
524
+ shape for shape, min_level in SM100_MMA_SHAPES_MXF4NVF4_DENSE_2SM.items() if tcgen05_level >= min_level
525
+ ]
526
+
527
+ for shape in shapes_1sm:
528
+ if enable_runtime_dtype:
529
+
530
+ runtime_types = [ DataType.f4 ]
531
+
532
+ for a_type, b_type in product(runtime_types, repeat=2):
533
+ math_instructions_1sm.append(
534
+ MathInstruction(
535
+ shape,
536
+ a_type, b_type, DataType.f32,
537
+ OpcodeClass.BlockScaledTensorOp,
538
+ MathOperation.multiply_add,
539
+ DataType.ue8m0)
540
+ )
541
+ math_instructions_1sm.append(
542
+ MathInstruction(
543
+ shape,
544
+ a_type, b_type, DataType.f32,
545
+ OpcodeClass.BlockScaledTensorOp,
546
+ MathOperation.multiply_add,
547
+ DataType.ue4m3)
548
+ )
549
+
550
+
551
+ if enable_compile_time_dtype:
552
+ compile_time_types = [ DataType.e2m1,
553
+ ]
554
+
555
+ for a_type, b_type in product(compile_time_types, repeat=2):
556
+ math_instructions_1sm.append(
557
+ MathInstruction(
558
+ shape,
559
+ a_type, b_type, DataType.f32,
560
+ OpcodeClass.BlockScaledTensorOp,
561
+ MathOperation.multiply_add,
562
+ DataType.ue8m0)
563
+ )
564
+ math_instructions_1sm.append(
565
+ MathInstruction(
566
+ shape,
567
+ a_type, b_type, DataType.f32,
568
+ OpcodeClass.BlockScaledTensorOp,
569
+ MathOperation.multiply_add,
570
+ DataType.ue4m3)
571
+ )
572
+
573
+
574
+ for shape in shapes_2sm:
575
+ if enable_runtime_dtype:
576
+
577
+ runtime_types = [ DataType.f4 ]
578
+
579
+ for a_type, b_type in product(runtime_types, repeat=2):
580
+ math_instructions_2sm.append(
581
+ MathInstruction(
582
+ shape,
583
+ a_type, b_type, DataType.f32,
584
+ OpcodeClass.BlockScaledTensorOp,
585
+ MathOperation.multiply_add,
586
+ DataType.ue8m0)
587
+ )
588
+ math_instructions_2sm.append(
589
+ MathInstruction(
590
+ shape,
591
+ a_type, b_type, DataType.f32,
592
+ OpcodeClass.BlockScaledTensorOp,
593
+ MathOperation.multiply_add,
594
+ DataType.ue4m3)
595
+ )
596
+
597
+
598
+ if enable_compile_time_dtype:
599
+ compile_time_types = [ DataType.e2m1,
600
+ ]
601
+
602
+ for a_type, b_type in product(compile_time_types, repeat=2):
603
+ math_instructions_2sm.append(
604
+ MathInstruction(
605
+ shape,
606
+ a_type, b_type, DataType.f32,
607
+ OpcodeClass.BlockScaledTensorOp,
608
+ MathOperation.multiply_add,
609
+ DataType.ue8m0)
610
+ )
611
+ math_instructions_2sm.append(
612
+ MathInstruction(
613
+ shape,
614
+ a_type, b_type, DataType.f32,
615
+ OpcodeClass.BlockScaledTensorOp,
616
+ MathOperation.multiply_add,
617
+ DataType.ue4m3)
618
+ )
619
+
620
+
621
+ return math_instructions_1sm, math_instructions_2sm
622
+
623
+
624
+ def generate_cluster_shapes_sm100(level: int, change_priority_func : Union[Callable, None] = None):
625
+ """
626
+ Generate all cluster shapes for SM100 at or above the given level.
627
+
628
+ Args:
629
+ level: The global level to generate cluster shapes for.
630
+
631
+ Returns:
632
+ A tuple of two lists of cluster shapes.
633
+ The first list contains the cluster shapes for 1SM, and the second list contains the cluster shapes for 2SM.
634
+ """
635
+ cluster_level = get_cluster_level_from_global_level(level)
636
+
637
+ assert cluster_level >= 4
638
+
639
+ if change_priority_func is not None:
640
+ SM100_CLUSTER_SHAPES_1SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_1SM)
641
+ SM100_CLUSTER_SHAPES_2SM_CPY = copy.deepcopy(SM100_CLUSTER_SHAPES_2SM)
642
+ change_priority_func(SM100_CLUSTER_SHAPES_1SM_CPY, SM100_CLUSTER_SHAPES_2SM_CPY)
643
+ shapes_1sm = [
644
+ list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM_CPY.items() if cluster_level >= min_level
645
+ ]
646
+ shapes_2sm = [
647
+ list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM_CPY.items() if cluster_level >= min_level
648
+ ]
649
+
650
+ return shapes_1sm, shapes_2sm
651
+
652
+ else:
653
+
654
+ shapes_1sm = [
655
+ list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_1SM.items() if cluster_level >= min_level
656
+ ]
657
+ shapes_2sm = [
658
+ list(shape) for shape, min_level in SM100_CLUSTER_SHAPES_2SM.items() if cluster_level >= min_level
659
+ ]
660
+
661
+ return shapes_1sm, shapes_2sm
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_shapes.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Valid WGMMA shapes, MMA multipliers, and cluster sizes for SM90, associated with levels.
35
+ These shape and level pairs are defined as dicts, where keys are shapes and values are their
36
+ associated levels. If the user input level for that category (MMA multiplier, WGMMA shape, cluster
37
+ size) is smaller than a shape's associated level, it will be excluded, and otherwise, included.
38
+ Higher levels are therefore less likely emitted, but lower levels are more emitted more frequently.
39
+ Level 0 is always emitted. The default behavior in `generator.py` is that level 1 is only emitted
40
+ when the `--kernel` argument is non-empty.
41
+ """
42
+
43
+ # NOTE: more combinations are possible here.
44
+ # Levels [0, 3] exist in order to control exactly what configs are generated in different dtypes.
45
+ # The rest are only used in the exhaustive mode (when the corresponding level digit is 9).
46
+ # MMA multipliers are multiplied by MMA instruction shapes (WGMMA shapes) to get CTA shapes.
47
+ SM90_MMA_MULTIPLIERS = {
48
+ (2, 1, 4): 0,
49
+ (1, 1, 4): 1,
50
+ (4, 1, 4): 2,
51
+ (2, 2, 4): 3,
52
+ (2, 1, 8): 4,
53
+ (4, 1, 8): 4,
54
+ (1, 1, 8): 4,
55
+ (2, 2, 8): 4,
56
+ (2, 1, 16): 5,
57
+ (4, 1, 16): 5,
58
+ (1, 1, 16): 5,
59
+ (2, 2, 16): 5,
60
+ }
61
+
62
+ # Level 0: only (1, 2, 1) -- fp8 dense gemms in pruned case
63
+ # Level 1: clusters with 2 CTAs -- all but fp8 (s8, u8, f16, b16, f32, tf32) dense gemms in pruned case
64
+ # Level 2: clusters with 1 or 2 CTAs
65
+ # Level 3: clusters with 1, 2, or 4 CTAs
66
+ # Level 4: clusters with 1, 2, 4, or 8 CTAs
67
+ # Level 5: clusters with 1, 2, 4, 8, or 16 CTAs
68
+ SM90_CLUSTER_SIZES = {
69
+ (1, 2, 1): 0,
70
+ (2, 1, 1): 1,
71
+ (1, 1, 1): 2,
72
+ (2, 2, 1): 3,
73
+ (1, 4, 1): 3,
74
+ (4, 1, 1): 3,
75
+ (2, 4, 1): 4,
76
+ (4, 2, 1): 4,
77
+ (1, 8, 1): 4,
78
+ (8, 1, 1): 4,
79
+ (4, 4, 1): 5,
80
+ }
81
+
82
+
83
+ # WGMMA shapes
84
+ # Level 0: "default" shape only,
85
+ # Level 1: additional shapes for the unpruned case (tf32 only)
86
+ # Level 2: shapes that are all powers of 2
87
+ # Level 3: all other shapes
88
+ SM90_WGMMA_SHAPES_FP16_BF16_DENSE = {
89
+ (64, 8, 16): 2,
90
+ (64, 16, 16): 2,
91
+ (64, 24, 16): 3,
92
+ (64, 32, 16): 2,
93
+ (64, 40, 16): 3,
94
+ (64, 48, 16): 3,
95
+ (64, 56, 16): 3,
96
+ (64, 64, 16): 2,
97
+ (64, 72, 16): 3,
98
+ (64, 80, 16): 3,
99
+ (64, 88, 16): 3,
100
+ (64, 96, 16): 3,
101
+ (64, 104, 16): 3,
102
+ (64, 112, 16): 3,
103
+ (64, 120, 16): 3,
104
+ (64, 128, 16): 0,
105
+ (64, 136, 16): 3,
106
+ (64, 144, 16): 3,
107
+ (64, 152, 16): 3,
108
+ (64, 160, 16): 3,
109
+ (64, 168, 16): 3,
110
+ (64, 176, 16): 3,
111
+ (64, 184, 16): 3,
112
+ (64, 192, 16): 3,
113
+ (64, 200, 16): 3,
114
+ (64, 208, 16): 3,
115
+ (64, 216, 16): 3,
116
+ (64, 224, 16): 3,
117
+ (64, 232, 16): 3,
118
+ (64, 240, 16): 3,
119
+ (64, 248, 16): 3,
120
+ (64, 256, 16): 1,
121
+ }
122
+
123
+ SM90_WGMMA_SHAPES_TF32_DENSE = {
124
+ (64, 8, 8): 2,
125
+ (64, 16, 8): 2,
126
+ (64, 24, 8): 3,
127
+ (64, 32, 8): 2,
128
+ (64, 40, 8): 3,
129
+ (64, 48, 8): 3,
130
+ (64, 56, 8): 3,
131
+ (64, 64, 8): 2,
132
+ (64, 72, 8): 3,
133
+ (64, 80, 8): 3,
134
+ (64, 88, 8): 3,
135
+ (64, 96, 8): 3,
136
+ (64, 104, 8): 3,
137
+ (64, 112, 8): 3,
138
+ (64, 120, 8): 3,
139
+ (64, 128, 8): 0,
140
+ (64, 136, 8): 3,
141
+ (64, 144, 8): 3,
142
+ (64, 152, 8): 3,
143
+ (64, 160, 8): 3,
144
+ (64, 168, 8): 3,
145
+ (64, 176, 8): 3,
146
+ (64, 184, 8): 3,
147
+ (64, 192, 8): 3,
148
+ (64, 200, 8): 3,
149
+ (64, 208, 8): 3,
150
+ (64, 216, 8): 3,
151
+ (64, 224, 8): 3,
152
+ (64, 232, 8): 3,
153
+ (64, 240, 8): 3,
154
+ (64, 248, 8): 3,
155
+ (64, 256, 8): 1,
156
+ }
157
+
158
+ SM90_WGMMA_SHAPES_FP8_DENSE = {
159
+ (64, 8, 32): 2,
160
+ (64, 16, 32): 2,
161
+ (64, 24, 32): 3,
162
+ (64, 32, 32): 2,
163
+ (64, 40, 32): 3,
164
+ (64, 48, 32): 3,
165
+ (64, 56, 32): 3,
166
+ (64, 64, 32): 2,
167
+ (64, 72, 32): 3,
168
+ (64, 80, 32): 3,
169
+ (64, 88, 32): 3,
170
+ (64, 96, 32): 3,
171
+ (64, 104, 32): 3,
172
+ (64, 112, 32): 3,
173
+ (64, 120, 32): 3,
174
+ (64, 128, 32): 0,
175
+ (64, 136, 32): 3,
176
+ (64, 144, 32): 3,
177
+ (64, 152, 32): 3,
178
+ (64, 160, 32): 3,
179
+ (64, 168, 32): 3,
180
+ (64, 176, 32): 3,
181
+ (64, 184, 32): 3,
182
+ (64, 192, 32): 3,
183
+ (64, 200, 32): 3,
184
+ (64, 208, 32): 3,
185
+ (64, 216, 32): 3,
186
+ (64, 224, 32): 3,
187
+ (64, 232, 32): 3,
188
+ (64, 240, 32): 3,
189
+ (64, 248, 32): 3,
190
+ (64, 256, 32): 1,
191
+ }
192
+
193
+ SM90_WGMMA_SHAPES_INT8_DENSE = {
194
+ (64, 8, 32): 2,
195
+ (64, 16, 32): 2,
196
+ (64, 24, 32): 3,
197
+ (64, 32, 32): 2,
198
+ (64, 48, 32): 3,
199
+ (64, 64, 32): 2,
200
+ (64, 80, 32): 3,
201
+ (64, 96, 32): 3,
202
+ (64, 112, 32): 3,
203
+ (64, 128, 32): 0,
204
+ (64, 144, 32): 3,
205
+ (64, 160, 32): 3,
206
+ (64, 176, 32): 3,
207
+ (64, 192, 32): 3,
208
+ (64, 208, 32): 3,
209
+ (64, 224, 32): 3,
210
+ (64, 240, 32): 3,
211
+ (64, 256, 32): 1,
212
+ }
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/sm90_utils.py ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for enumerating CUTLASS library SM90 kernels
35
+ """
36
+
37
+ import argparse
38
+ import enum
39
+ from itertools import product
40
+ import math
41
+ import logging
42
+ import os.path
43
+ import shutil
44
+ import sys
45
+ import copy
46
+ from typing import Any, Optional, Sequence, Tuple, List
47
+
48
+ try:
49
+ import builtins
50
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
51
+ raise ImportError("Disabling attempt to import cutlass_library")
52
+ from cutlass_library.library import *
53
+ except ImportError:
54
+ from library import *
55
+
56
+ # NOTE: this is a duplicate of CudaToolkitVersionSatisfies in generator.py
57
+ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0):
58
+
59
+ # by default, use the latest CUDA Toolkit version
60
+ cuda_version = [11, 0, 132]
61
+
62
+ # Update cuda_version based on parsed string
63
+ if semantic_ver_string != '':
64
+ for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]):
65
+ if i < len(cuda_version):
66
+ cuda_version[i] = x
67
+ else:
68
+ cuda_version.append(x)
69
+ return cuda_version >= [major, minor, patch]
70
+
71
+ #### Step 0: define levels
72
+
73
+ # One integer level controls multiple "generators" and how many
74
+ # combinations they generate. That is the "global" level.
75
+ # "Generators" are WGMMA shapes, MMA multipliers, cluster sizes, and
76
+ # anything that is eventually involved in the Cartesian product
77
+ # which yields our kernel configurations.
78
+ # For simplicity, each generator defines their own levels,
79
+ # starting from 0. As a rule we assume 10 or fewer levels, making
80
+ # their level a digit.
81
+ # The "global" level simply stacks these digits and represents them
82
+ # as a single integer.
83
+ #
84
+ # For example, level 500 indicates cluster sizes are at level 5, MMA
85
+ # multipliers are at level 0, and WGMMA shapes are at level 0 as well.
86
+ #
87
+ # Here we define the global level to generator level mappings.
88
+
89
+
90
+ def get_wgmma_level_from_global_level(global_level: int):
91
+ return global_level % 10
92
+
93
+
94
+ def get_mma_level_from_global_level(global_level: int):
95
+ return (global_level // 10) % 10
96
+
97
+
98
+ def get_cluster_level_from_global_level(global_level: int):
99
+ return (global_level // 100) % 10
100
+
101
+
102
+ def get_pruning_level_from_global_level(global_level: int):
103
+ return (global_level // 1000) % 10
104
+
105
+
106
+ #### Step 1: generate MMA instruction shapes based on levels
107
+
108
+ try:
109
+ from .sm90_shapes import (
110
+ SM90_MMA_MULTIPLIERS,
111
+ SM90_CLUSTER_SIZES,
112
+ SM90_WGMMA_SHAPES_TF32_DENSE,
113
+ SM90_WGMMA_SHAPES_FP16_BF16_DENSE,
114
+ SM90_WGMMA_SHAPES_FP8_DENSE,
115
+ SM90_WGMMA_SHAPES_INT8_DENSE,
116
+ )
117
+ except:
118
+ from sm90_shapes import (
119
+ SM90_MMA_MULTIPLIERS,
120
+ SM90_CLUSTER_SIZES,
121
+ SM90_WGMMA_SHAPES_TF32_DENSE,
122
+ SM90_WGMMA_SHAPES_FP16_BF16_DENSE,
123
+ SM90_WGMMA_SHAPES_FP8_DENSE,
124
+ SM90_WGMMA_SHAPES_INT8_DENSE,
125
+ )
126
+
127
+
128
+ def generate_tf32_math_instruction_shapes_sm90(level: int):
129
+ assert isinstance(level, int) and level >= 0
130
+ filtered_list_of_wgmma_shapes = [
131
+ wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_TF32_DENSE.items() if level >= min_level
132
+ ]
133
+ return filtered_list_of_wgmma_shapes
134
+
135
+ def generate_fp16_bf16_math_instruction_shapes_sm90(level: int):
136
+ assert isinstance(level, int) and level >= 0
137
+ filtered_list_of_wgmma_shapes = [
138
+ wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP16_BF16_DENSE.items() if level >= min_level
139
+ ]
140
+ return filtered_list_of_wgmma_shapes
141
+
142
+ def generate_fp8_math_instruction_shapes_sm90(level: int):
143
+ assert isinstance(level, int) and level >= 0
144
+ filtered_list_of_wgmma_shapes = [
145
+ wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_FP8_DENSE.items() if level >= min_level
146
+ ]
147
+ return filtered_list_of_wgmma_shapes
148
+
149
+ def generate_int8_math_instruction_shapes_sm90(level: int):
150
+ assert isinstance(level, int) and level >= 0
151
+ filtered_list_of_wgmma_shapes = [
152
+ wgmma_shape for wgmma_shape, min_level in SM90_WGMMA_SHAPES_INT8_DENSE.items() if level >= min_level
153
+ ]
154
+ return filtered_list_of_wgmma_shapes
155
+
156
+ def generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level: int, a_type: DataType, b_type: DataType):
157
+ # DataTypeSize are in the unit of bits
158
+ a_bytes = DataTypeSize[a_type] // 8
159
+ b_bytes = DataTypeSize[b_type] // 8
160
+ if a_bytes == 4 or b_bytes == 4:
161
+ return generate_tf32_math_instruction_shapes_sm90(wgmma_level)
162
+ elif a_bytes == 2 or b_bytes == 2:
163
+ return generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level)
164
+ else:
165
+ return generate_fp8_math_instruction_shapes_sm90(wgmma_level)
166
+
167
+ ###########
168
+
169
+ def generate_tf32_math_instructions_sm90(level: int):
170
+ wgmma_level = get_wgmma_level_from_global_level(level)
171
+ math_instructions = []
172
+ for math_instruction_shape in generate_tf32_math_instruction_shapes_sm90(wgmma_level):
173
+ math_instructions.append(
174
+ MathInstruction(
175
+ math_instruction_shape,
176
+ DataType.tf32, DataType.tf32, DataType.f32,
177
+ OpcodeClass.TensorOp,
178
+ MathOperation.multiply_add)
179
+ )
180
+ return math_instructions
181
+
182
+ def generate_fp16_bf16_math_instructions_sm90(level: int):
183
+ wgmma_level = get_wgmma_level_from_global_level(level)
184
+ math_instructions = []
185
+ for math_instruction_shape in generate_fp16_bf16_math_instruction_shapes_sm90(wgmma_level):
186
+ math_instructions += [
187
+ MathInstruction(
188
+ math_instruction_shape,
189
+ DataType.f16, DataType.f16, DataType.f16,
190
+ OpcodeClass.TensorOp,
191
+ MathOperation.multiply_add),
192
+ MathInstruction(
193
+ math_instruction_shape,
194
+ DataType.f16, DataType.f16, DataType.f32,
195
+ OpcodeClass.TensorOp,
196
+ MathOperation.multiply_add),
197
+ MathInstruction(
198
+ math_instruction_shape,
199
+ DataType.bf16, DataType.bf16, DataType.f32,
200
+ OpcodeClass.TensorOp,
201
+ MathOperation.multiply_add),
202
+ ]
203
+ return math_instructions
204
+
205
+ def generate_fp8_math_instructions_sm90(level: int):
206
+ wgmma_level = get_wgmma_level_from_global_level(level)
207
+ math_instructions = []
208
+ for math_instruction_shape in generate_fp8_math_instruction_shapes_sm90(wgmma_level):
209
+ math_instructions += [
210
+ MathInstruction(
211
+ math_instruction_shape,
212
+ DataType.e4m3, DataType.e4m3, DataType.f32,
213
+ OpcodeClass.TensorOp,
214
+ MathOperation.multiply_add),
215
+ MathInstruction(
216
+ math_instruction_shape,
217
+ DataType.e4m3, DataType.e5m2, DataType.f32,
218
+ OpcodeClass.TensorOp,
219
+ MathOperation.multiply_add),
220
+ MathInstruction(
221
+ math_instruction_shape,
222
+ DataType.e5m2, DataType.e4m3, DataType.f32,
223
+ OpcodeClass.TensorOp,
224
+ MathOperation.multiply_add),
225
+ MathInstruction(
226
+ math_instruction_shape,
227
+ DataType.e5m2, DataType.e5m2, DataType.f32,
228
+ OpcodeClass.TensorOp,
229
+ MathOperation.multiply_add),
230
+ ]
231
+ return math_instructions
232
+
233
+ def generate_mixed_dtype_math_instructions_sm90(level: int, types_of_a_b_acc: List[Tuple[DataType, DataType, DataType]]):
234
+ wgmma_level = get_wgmma_level_from_global_level(level)
235
+ math_instructions = []
236
+ for a_type, b_type, acc_type in types_of_a_b_acc:
237
+ math_instruction_shapes = generate_mixed_dtype_math_instructions_shapes_sm90(wgmma_level, a_type, b_type)
238
+ for math_instruction_shape in math_instruction_shapes:
239
+ math_instructions += [
240
+ MathInstruction(
241
+ math_instruction_shape,
242
+ a_type, b_type, acc_type,
243
+ OpcodeClass.TensorOp,
244
+ MathOperation.multiply_add
245
+ ),
246
+ ]
247
+ return math_instructions
248
+
249
+ def generate_int8_math_instructions_sm90(level: int):
250
+ wgmma_level = get_wgmma_level_from_global_level(level)
251
+ math_instructions = []
252
+ for math_instruction_shape in generate_int8_math_instruction_shapes_sm90(wgmma_level):
253
+ math_instructions += [
254
+ MathInstruction(
255
+ math_instruction_shape,
256
+ DataType.s8, DataType.s8, DataType.s32,
257
+ OpcodeClass.TensorOp,
258
+ MathOperation.multiply_add),
259
+ MathInstruction(
260
+ math_instruction_shape,
261
+ DataType.u8, DataType.u8, DataType.s32,
262
+ OpcodeClass.TensorOp,
263
+ MathOperation.multiply_add),
264
+ ]
265
+ return math_instructions
266
+
267
+ def make_sparse_math_instructions(math_instructions):
268
+ sparse_instructions = []
269
+ for inst in math_instructions:
270
+ if inst.opcode_class == OpcodeClass.TensorOp:
271
+ sparse_instructions.append(MathInstruction(
272
+ (inst.instruction_shape[0], inst.instruction_shape[1], inst.instruction_shape[2] * 2),
273
+ inst.element_a, inst.element_b, inst.element_accumulator,
274
+ OpcodeClass.SparseTensorOp,
275
+ inst.math_operation),)
276
+ return sparse_instructions
277
+
278
+
279
+ #### Step 2: generate tile descriptions from math instruction shapes
280
+
281
+ def is_tile_desc_valid(tile_description):
282
+ if tile_description.minimum_compute_capability != 90 or tile_description.maximum_compute_capability != 90:
283
+ return False
284
+
285
+ element_a, element_b, element_accum = (
286
+ tile_description.math_instruction.element_a,
287
+ tile_description.math_instruction.element_b,
288
+ tile_description.math_instruction.element_accumulator
289
+ )
290
+
291
+ cluster_size, cta_shape = (
292
+ tile_description.cluster_shape,
293
+ tile_description.threadblock_shape,
294
+ )
295
+ grid_size = (
296
+ cta_shape[0] * cluster_size[0] +
297
+ cta_shape[1] * cluster_size[1] +
298
+ cta_shape[2] * cluster_size[2]
299
+ )
300
+ num_ctas_in_cluster = cluster_size[0] * cluster_size[1] * cluster_size[2]
301
+ cluster_shape = (
302
+ cluster_size[0] * cta_shape[0],
303
+ cluster_size[1] * cta_shape[1],
304
+ cluster_size[2] * cta_shape[2]
305
+ )
306
+
307
+ FP32_TYPES = [DataType.f32, DataType.tf32]
308
+ FP16_TYPES = [DataType.f16, DataType.bf16]
309
+ is_fp32 = element_a in FP32_TYPES and element_b in FP32_TYPES
310
+ is_fp16 = element_a in FP16_TYPES and element_b in FP16_TYPES
311
+
312
+ # Maximum number of CTAs per cluster is 8 for Hopper, but up to 16 is
313
+ # allowed for non portable clusters.
314
+ if num_ctas_in_cluster > 16 or num_ctas_in_cluster < 1:
315
+ return False
316
+
317
+ if grid_size < 1:
318
+ return False
319
+
320
+ # SM90 WGMMA shapes are always 64 across M, therefore
321
+ # CTA shape across M must always be a multiple of 64.
322
+ if cta_shape[0] < 64 or cta_shape[0] % 64 != 0:
323
+ return False
324
+
325
+ # The minimum WGMMA shape across N is 8, and increments
326
+ # vary across different dtypes, but they're never smaller
327
+ # than 8. The minimum CTA shape allowed across N though is 16.
328
+ if cta_shape[1] < 16 or cta_shape[1] % 8 != 0:
329
+ return False
330
+
331
+ # SM90 WGMMA shapes across K are always 8 for 32 bit dense
332
+ # operations, 16 for 16 bit, and 32 for 8 bit. In any case,
333
+ # the CTA shape across K should be a multiple of 8 and at least
334
+ # twice the WGMMA shape across K.
335
+ if cta_shape[2] < 16 or cta_shape[2] % 8 != 0:
336
+ return False
337
+
338
+ # Minimum of 2 stages (very rough heuristic that may filter out valid kernel configs)
339
+ if (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 256:
340
+ return False
341
+
342
+ if is_fp32 and (cluster_shape[0] >= 128 or cluster_shape[1] >= 128) and cluster_shape[2] >= 128:
343
+ return False
344
+
345
+ if is_fp32 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 64:
346
+ return False
347
+
348
+ if is_fp16 and cluster_shape[0] >= 256 and cluster_shape[1] >= 256 and cluster_shape[2] >= 128:
349
+ return False
350
+
351
+ # CTA shape upper bound: <256, 256, 256>
352
+ if cta_shape[0] > 256 or cta_shape[1] > 256 or cta_shape[2] > 256:
353
+ return False
354
+
355
+ return True
356
+
357
+ def get_mma_multipliers(level: int):
358
+ assert isinstance(level, int) and level >= 0
359
+ mma_level = get_mma_level_from_global_level(level)
360
+ return [
361
+ mma_mul for mma_mul, mma_min_level in SM90_MMA_MULTIPLIERS.items() if mma_level >= mma_min_level
362
+ ]
363
+
364
+ def get_cluster_sizes(level: int, is_aligned: bool):
365
+ if not is_aligned:
366
+ return [(1, 1, 1)]
367
+ assert isinstance(level, int) and level >= 0
368
+ cluster_level = get_cluster_level_from_global_level(level)
369
+ return [
370
+ cluster_size for cluster_size, cluster_min_level in SM90_CLUSTER_SIZES.items() if cluster_level >= cluster_min_level
371
+ ]
372
+
373
+ def generate_tile_descriptions_sm90(math_instructions, is_aligned: bool, level: int):
374
+ tile_descriptions = set()
375
+ mma_multipliers, cluster_sizes = get_mma_multipliers(level), get_cluster_sizes(level, is_aligned)
376
+ for math_inst, mma_mul, cluster_size in product(math_instructions, mma_multipliers, cluster_sizes):
377
+
378
+ # generator can stamp out duplicate kernels, because it doesn't explicitly set instruction
379
+ # shape for SM90 kernels, and the 3.X collective API doesn't directly expose them when using
380
+ # the auto kernel schedule.
381
+
382
+ math_inst_stub = copy.deepcopy(math_inst)
383
+ math_inst_stub.instruction_shape = [0, 0, 0]
384
+
385
+ tile_desc = TileDescription(
386
+ threadblock_shape=[
387
+ math_inst.instruction_shape[0] * mma_mul[0],
388
+ math_inst.instruction_shape[1] * mma_mul[1],
389
+ math_inst.instruction_shape[2] * mma_mul[2]
390
+ ],
391
+ stages=0,
392
+ warp_count=[4, 1, 1],
393
+ math_instruction=math_inst_stub,
394
+ min_compute=90,
395
+ max_compute=90,
396
+ cluster_shape=cluster_size)
397
+ # For sparse kernels K-tile is twice as large (due to 2x MMA-K size)
398
+ # Reduce it to same size as dense to afford more smem stages
399
+ if math_inst.opcode_class == OpcodeClass.SparseTensorOp:
400
+ tile_desc.threadblock_shape[2] = tile_desc.threadblock_shape[2] // 2
401
+ if is_tile_desc_valid(tile_desc):
402
+ tile_descriptions.add(tile_desc)
403
+
404
+ return tile_descriptions
405
+
406
+ #### Step 3: map tile description to valid schedules
407
+
408
+ def is_tile_desc_compatible_with_cooperative(tile_description):
409
+ # Cooperative kernels require a minimum CTA-M of 128
410
+ return tile_description.threadblock_shape[0] % 128 == 0
411
+
412
+
413
+ def can_tile_desc_use_shmem_in_epilogue(tile_description, data_types):
414
+ dtype_a, dtype_b, dtype_c, dtype_d, dtype_acc, dtype_epi = (
415
+ data_types["a_type"],
416
+ data_types["b_type"],
417
+ data_types["c_type"],
418
+ data_types["d_type"],
419
+ data_types["acc_type"],
420
+ data_types["epi_type"]
421
+ )
422
+ mn = tile_description.threadblock_shape[0] * tile_description.threadblock_shape[1]
423
+ bitsize_c, bitsize_d = DataTypeSize[dtype_c], DataTypeSize[dtype_d]
424
+
425
+ shmem_bits_c, shmem_bits_d = bitsize_c * mn, bitsize_d * mn
426
+ shmem_bits_total = shmem_bits_c + shmem_bits_d
427
+ # Magic number: 2^20
428
+ # Existing logic suggested that tile shape 256x128 (or 128x256)
429
+ # would run out of shmem if D is FP32, and source is needed.
430
+ # That would be 256 * 128 * 32 == 2^21 (~262 KB), which is over the limit.
431
+ # Hopper's max shmem size is 228 KB, and 2^20 ~= 131 KB.
432
+ # Since epilogue can't possibly use ALL of the shmem available
433
+ # we can just settle on 2^20 bits (~ 131 KB) being the upper bound
434
+ # we would allow for epilogue.
435
+ # This can be different for non-persistent kernels where epilogue and
436
+ # mainloop shmem is shared.
437
+ if shmem_bits_total > 2 ** 20:
438
+ return False
439
+
440
+ return True
441
+
442
+
443
+ def get_valid_schedules(tile_description, cuda_version, is_aligned, data_types, layout,
444
+ instantiation_level, enable_fp8_fast_acc=True, gemm_kind=GemmKind.Universal3x):
445
+ # Level 0: prune according to existing generator.py behavior
446
+ # Level >= 1: no pruning
447
+ level = get_pruning_level_from_global_level(instantiation_level)
448
+ schedules = []
449
+ stream_k_schedules = []
450
+
451
+ if not is_tile_desc_valid(tile_description):
452
+ return schedules, stream_k_schedules
453
+
454
+ FP16_TYPES = [DataType.f16, DataType.bf16]
455
+ is_fp16 = data_types["a_type"] in FP16_TYPES and data_types["b_type"] in FP16_TYPES
456
+
457
+ FP8_TYPES = [DataType.e4m3, DataType.e5m2]
458
+ is_fp8 = data_types["a_type"] in FP8_TYPES and data_types["b_type"] in FP8_TYPES
459
+ can_do_fp8_fast_accum = is_fp8 and enable_fp8_fast_acc
460
+
461
+ FP32_TYPES = [DataType.f32, DataType.tf32]
462
+ is_fp32 = data_types["a_type"] in FP32_TYPES and data_types["b_type"] in FP32_TYPES
463
+ requires_transposed_epilogue = is_fp32 and layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.RowMajor
464
+
465
+ can_do_cooperative = is_tile_desc_compatible_with_cooperative(tile_description)
466
+ can_do_tma_epilogue = is_aligned and not requires_transposed_epilogue and can_tile_desc_use_shmem_in_epilogue(tile_description, data_types)
467
+
468
+ default_epilogue = EpilogueScheduleType.NoSmemWarpSpecialized if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed
469
+ auto_epilogue = EpilogueScheduleType.ScheduleAuto if not requires_transposed_epilogue else EpilogueScheduleType.EpilogueTransposed
470
+
471
+ cta_m, cta_n, cta_k = (
472
+ tile_description.threadblock_shape[0],
473
+ tile_description.threadblock_shape[1],
474
+ tile_description.threadblock_shape[2]
475
+ )
476
+ c_type = data_types["c_type"]
477
+ d_type = data_types["d_type"]
478
+ is_void_c = c_type == DataType.void
479
+
480
+ # Filter out invalid kernels
481
+ is_nt = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.RowMajor
482
+ is_tn = layout[0][0] == LayoutType.RowMajor and layout[1][0] == LayoutType.ColumnMajor
483
+ is_nn = layout[0][0] == LayoutType.ColumnMajor and layout[1][0] == LayoutType.ColumnMajor
484
+
485
+ # static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0,
486
+ # "Copy size must evenly divide SMEM tile.");
487
+ if is_fp32 and is_nt and (cta_n % cta_k != 0):
488
+ return [], []
489
+
490
+ # static_assert(!TransposeB || (cutlass::bits_to_bytes((size<1>(SmemLayoutB{}) * sizeof_bits<InternalElementB>::value))) == 128,
491
+ # "SmemLayoutB K must be 128bytes to be transposed.")
492
+ if is_fp32 and is_nt and cta_k != 32:
493
+ return [], []
494
+
495
+ # Static assert failure when instantiating SmemLayoutB
496
+ if is_fp32 and (is_tn or is_nn) and (cta_n % cta_k != 0):
497
+ return [], []
498
+
499
+ grouped = is_grouped(gemm_kind)
500
+ if grouped:
501
+ # the following cases are unsupported by grouped GEMM
502
+ if not is_aligned:
503
+ return [], []
504
+ if requires_transposed_epilogue:
505
+ return [], []
506
+
507
+ # Early pruning
508
+ if level < 1:
509
+ # Don't stamp out FP16/BF16 kernels smaller than or equal to 64x128x64
510
+ if is_fp16 and cta_m <= 64 and cta_n <= 128 and cta_k <= 64:
511
+ return [], []
512
+
513
+ # FP8 configs with CTA tile larger than or equal to 256x128x128 limit data types and schedules
514
+ is_large_fp8_tile = is_fp8 and cta_m >= 256 and cta_n >= 128 and cta_k >= 128
515
+ if is_large_fp8_tile:
516
+ # Only void-C, and only FP8 outputs allowed
517
+ if not is_void_c or d_type not in FP8_TYPES:
518
+ return [], []
519
+ if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative and can_do_tma_epilogue:
520
+ schedules = []
521
+ if is_blockwise(gemm_kind):
522
+ schedules.append(
523
+ [
524
+ to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
525
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
526
+ ])
527
+ else:
528
+ schedules.append(
529
+ [
530
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
531
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
532
+ ])
533
+ schedules.append(
534
+ [
535
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
536
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
537
+ ])
538
+ return schedules, []
539
+ return [], []
540
+
541
+ if is_fp8 and not is_large_fp8_tile:
542
+ valid_dtypes_for_c = [DataType.f32, DataType.bf16, DataType.f16, DataType.void]
543
+ # Prune all configs with fp8 source, and all configs with non-fp8 output
544
+ # that have different dtypes for source and output.
545
+ if c_type not in valid_dtypes_for_c or (d_type not in FP8_TYPES and c_type != d_type):
546
+ return [], []
547
+
548
+ # FP32/TF32 kernels don't stamp out void-C
549
+ if is_fp32 and is_void_c:
550
+ return [], []
551
+
552
+ # Void-c only makes a difference for TMA epilogues
553
+ if is_void_c and not can_do_tma_epilogue:
554
+ return [], []
555
+
556
+ # For mixed input data types
557
+ a_type_size = DataTypeSize[data_types["a_type"]]
558
+ b_type_size = DataTypeSize[data_types["b_type"]]
559
+ if a_type_size != b_type_size and CudaToolkitVersionSatisfies(cuda_version, 12, 1):
560
+ schedules = []
561
+ stream_k_schedules = []
562
+ epilogue_schedule = EpilogueScheduleType.TmaWarpSpecialized
563
+ if a_type_size > b_type_size:
564
+ epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
565
+
566
+ if not is_blockwise(gemm_kind):
567
+ schedules.append([
568
+ KernelScheduleType.TmaWarpSpecialized,
569
+ epilogue_schedule
570
+ ])
571
+ schedules.append([
572
+ KernelScheduleType.TmaWarpSpecializedPingpong,
573
+ epilogue_schedule
574
+ ])
575
+ if cta_m >= 128:
576
+ if a_type_size > b_type_size:
577
+ epilogue_schedule = EpilogueScheduleType.EpilogueTransposed
578
+ else:
579
+ epilogue_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative
580
+ if is_blockwise(gemm_kind):
581
+ schedules.append([
582
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
583
+ epilogue_schedule
584
+ ])
585
+ else:
586
+ schedules.append([
587
+ KernelScheduleType.TmaWarpSpecializedCooperative,
588
+ epilogue_schedule
589
+ ])
590
+ stream_k_schedules.append([
591
+ KernelScheduleType.TmaWarpSpecializedCooperative,
592
+ epilogue_schedule
593
+ ])
594
+ return schedules, stream_k_schedules
595
+
596
+ if not is_aligned and not is_blockwise(gemm_kind):
597
+ schedules = [[KernelScheduleType.CpAsyncWarpSpecialized,
598
+ default_epilogue]]
599
+ stream_k_schedules = []
600
+
601
+ if CudaToolkitVersionSatisfies(cuda_version, 12, 1) and can_do_cooperative:
602
+ schedules.append([
603
+ KernelScheduleType.CpAsyncWarpSpecializedCooperative,
604
+ default_epilogue
605
+ ])
606
+ stream_k_schedules.append([
607
+ KernelScheduleType.CpAsyncWarpSpecializedCooperative,
608
+ default_epilogue
609
+ ])
610
+
611
+ return schedules, stream_k_schedules
612
+
613
+ schedules = []
614
+ # Pruning: emit Void-C and Grouped kernels with persistent kernels only
615
+ if (level >= 1 or not is_void_c) and not grouped and not is_blockwise(gemm_kind):
616
+ # Pruning: don't stamp out fp8 kernels with auto schedule
617
+ if not is_fp8:
618
+ schedules.append([KernelScheduleType.ScheduleAuto, auto_epilogue])
619
+ schedules.append([KernelScheduleType.TmaWarpSpecialized, default_epilogue])
620
+ stream_k_schedules = []
621
+
622
+ if CudaToolkitVersionSatisfies(cuda_version, 12, 0):
623
+ if can_do_tma_epilogue:
624
+ assert not requires_transposed_epilogue
625
+ # Inconsistency: fp8 pingpong only gets stamped out with fast accum
626
+ if (not is_fp8 or level >= 1) and not is_blockwise(gemm_kind):
627
+ schedules.append([
628
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped),
629
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
630
+ ])
631
+ if can_do_fp8_fast_accum:
632
+ schedules.append([
633
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped),
634
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecialized, grouped)
635
+ ])
636
+
637
+ if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
638
+ # Pruning: don't stamp out fp8 ping-pong kernel with non-tma epilogue
639
+ if not is_fp8 or level >= 1:
640
+ if not is_blockwise(gemm_kind):
641
+ schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
642
+ else:
643
+ schedules.append([to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedPingpong, grouped), to_grouped_schedule(default_epilogue, grouped)])
644
+
645
+ if can_do_fp8_fast_accum:
646
+ if not grouped:
647
+ schedules.append([KernelScheduleType.TmaWarpSpecializedFP8FastAccum, default_epilogue])
648
+ schedules.append([to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, grouped), to_grouped_schedule(default_epilogue, grouped)])
649
+
650
+ if can_do_cooperative:
651
+ if is_blockwise(gemm_kind):
652
+ schedules.append([
653
+ to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
654
+ to_grouped_schedule(default_epilogue, grouped)
655
+ ])
656
+ stream_k_schedules.append([
657
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
658
+ default_epilogue
659
+ ])
660
+ else:
661
+ schedules.append([
662
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
663
+ to_grouped_schedule(default_epilogue, grouped)
664
+ ])
665
+ stream_k_schedules.append([
666
+ KernelScheduleType.TmaWarpSpecializedCooperative,
667
+ default_epilogue
668
+ ])
669
+ if can_do_fp8_fast_accum:
670
+ schedules.append([
671
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
672
+ to_grouped_schedule(default_epilogue, grouped)
673
+ ])
674
+ stream_k_schedules.append([
675
+ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
676
+ default_epilogue
677
+ ])
678
+
679
+ # persistent kernels with TMA epilogues
680
+ if can_do_tma_epilogue:
681
+ assert not requires_transposed_epilogue
682
+ if can_do_cooperative:
683
+ if is_blockwise(gemm_kind):
684
+ schedules.append([
685
+ to_grouped_schedule(KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative, grouped),
686
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
687
+ ])
688
+ stream_k_schedules.append([
689
+ KernelScheduleType.BlockwiseTmaWarpSpecializedCooperative,
690
+ EpilogueScheduleType.TmaWarpSpecializedCooperative
691
+ ])
692
+ else:
693
+ schedules.append([
694
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperative, grouped),
695
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
696
+ ])
697
+ stream_k_schedules.append([
698
+ KernelScheduleType.TmaWarpSpecializedCooperative,
699
+ EpilogueScheduleType.TmaWarpSpecializedCooperative
700
+ ])
701
+ if can_do_fp8_fast_accum:
702
+ schedules.append([
703
+ to_grouped_schedule(KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, grouped),
704
+ to_grouped_schedule(EpilogueScheduleType.TmaWarpSpecializedCooperative, grouped)
705
+ ])
706
+ stream_k_schedules.append([
707
+ KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum,
708
+ EpilogueScheduleType.TmaWarpSpecializedCooperative
709
+ ])
710
+ # Grouped GEMM do not support Stream-K scheduler
711
+ if grouped:
712
+ return schedules, []
713
+ return schedules, stream_k_schedules
714
+
715
+
716
+ #### Misc: helpers
717
+
718
+ def generate_data_types_from_math_instruction(math_instruction, element_source = None, element_dest = None, element_epilogue = None):
719
+ element_a, element_b = math_instruction.element_a, math_instruction.element_b
720
+ element_accumulator = math_instruction.element_accumulator
721
+ element_c = element_source or element_accumulator
722
+ element_d = element_dest or element_accumulator
723
+ element_epilogue = element_epilogue or element_accumulator
724
+ data_types = {
725
+ "a_type" : element_a,
726
+ "b_type" : element_b,
727
+ "c_type" : element_c,
728
+ "d_type" : element_d,
729
+ "acc_type" : element_accumulator,
730
+ "epi_type" : element_epilogue
731
+ }
732
+ return data_types
733
+
734
+ def fix_alignments(data_types, layout, alignment_bits = 128):
735
+ operand_keys = ["a_type", "b_type", "c_type"]
736
+ operands_to_fix = ["c_type"]
737
+ new_layout = []
738
+ assert len(layout) == len(operand_keys)
739
+ for i, k in enumerate(operand_keys):
740
+ assert k in data_types and data_types[k] in DataTypeSize
741
+ dtype = data_types[k]
742
+ dtype_size_bits = DataTypeSize[dtype]
743
+
744
+ layout_type = layout[i][0]
745
+ layout_alignment = layout[i][1]
746
+
747
+ # Don't modify alignment if dtype's been changed to void
748
+ if k in operands_to_fix and dtype_size_bits >= 1:
749
+ layout_alignment = alignment_bits // dtype_size_bits
750
+
751
+ new_layout.append([layout_type, layout_alignment])
752
+
753
+ return new_layout
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/symm_operation.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting Symm kernels
35
+ """
36
+
37
+ import enum
38
+ import functools
39
+ import operator
40
+ import os.path
41
+ import shutil
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ except ImportError:
49
+ from library import *
50
+
51
+
52
+ ###################################################################################################
53
+ #
54
+ # Data structure modeling a Symm update operation
55
+ #
56
+ ###################################################################################################
57
+
58
+ #
59
+ class SymmOperation:
60
+ #
61
+ def __init__(self, symm_kind, arch, tile_description, A, B, C, element_epilogue, \
62
+ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, \
63
+ blas_mode = BlasMode.symmetric):
64
+
65
+ self.blas_mode = blas_mode
66
+ self.operation_kind = OperationKind.Symm
67
+ self.arch = arch
68
+ self.tile_description = tile_description
69
+ self.symm_kind = symm_kind
70
+ # tensor A and B have same data type and layout
71
+ self.A = A
72
+ self.B = B
73
+ self.C = C
74
+ self.element_epilogue = element_epilogue
75
+ self.epilogue_functor = epilogue_functor
76
+ self.swizzling_functor = swizzling_functor
77
+
78
+ #
79
+ def is_complex(self):
80
+ complex_operators = [
81
+ MathOperation.multiply_add_complex,
82
+ MathOperation.multiply_add_complex_gaussian,
83
+ MathOperation.multiply_add_complex_fast_f32
84
+ ]
85
+ return self.tile_description.math_instruction.math_operation in complex_operators
86
+ return False
87
+
88
+ #
89
+ def is_mixed_input(self):
90
+ return self.A.element != self.B.element
91
+
92
+ #
93
+ def is_planar_complex(self):
94
+ return False
95
+
96
+ #
97
+ def accumulator_type(self):
98
+ accum = self.tile_description.math_instruction.element_accumulator
99
+
100
+ if self.is_complex():
101
+ return get_complex_from_real(accum)
102
+
103
+ return accum
104
+
105
+ #
106
+ def short_math_name(self):
107
+ if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
108
+ return "g%s" % ShortDataTypeNames[self.accumulator_type()]
109
+ return ShortDataTypeNames[self.accumulator_type()]
110
+
111
+
112
+ #
113
+ def core_name(self):
114
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
115
+
116
+ inst_shape = ''
117
+ inst_operation = ''
118
+ intermediate_type = ''
119
+
120
+ math_operations_map = {
121
+ MathOperation.xor_popc: 'xor',
122
+ MathOperation.and_popc: 'and'
123
+ }
124
+
125
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
126
+ self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
127
+
128
+ math_op = self.tile_description.math_instruction.math_operation
129
+ math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
130
+
131
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
132
+ inst_shape += math_op_string
133
+
134
+ if self.tile_description.math_instruction.element_a != self.A.element and \
135
+ self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
136
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
137
+
138
+ operation_name = 'symm' if self.blas_mode == BlasMode.symmetric else 'hemm'
139
+
140
+ return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, operation_name)
141
+
142
+ #
143
+ def extended_name(self):
144
+ ''' Append data types if they differ from compute type. '''
145
+ if self.is_complex():
146
+ extended_name = "${core_name}"
147
+ else:
148
+ if self.C.element != self.tile_description.math_instruction.element_accumulator and \
149
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
150
+ extended_name = "${element_c}_${core_name}_${element_a}"
151
+ elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
152
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
153
+ extended_name = "${core_name}_${element_a}"
154
+ else:
155
+ extended_name = "${core_name}"
156
+
157
+ extended_name = SubstituteTemplate(extended_name, {
158
+ 'element_a': DataTypeNames[self.A.element],
159
+ 'element_c': DataTypeNames[self.C.element],
160
+ 'core_name': self.core_name()
161
+ })
162
+
163
+ return extended_name
164
+
165
+ #
166
+ def layout_name(self):
167
+ if self.is_complex() or self.is_planar_complex():
168
+ return "%s" % (
169
+ ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
170
+ )
171
+ return "%s" % (ShortLayoutTypeNames[self.A.layout])
172
+
173
+ #
174
+ def side_mode_name(self):
175
+ return "%s" % (ShortSideModeNames[self.A.side_mode])
176
+
177
+ #
178
+ def fill_mode_name(self):
179
+ return "%s" % (ShortFillModeNames[self.A.fill_mode])
180
+
181
+ #
182
+ def procedural_name(self):
183
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
184
+ threadblock = self.tile_description.procedural_name()
185
+
186
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
187
+
188
+ alignment = self.C.alignment
189
+
190
+ return SubstituteTemplate(
191
+ "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_align${alignment}",
192
+ {
193
+ 'opcode_class': opcode_class_name,
194
+ 'extended_name': self.extended_name(),
195
+ 'threadblock': threadblock,
196
+ 'layout': self.layout_name(),
197
+ 'side_mode': self.side_mode_name(),
198
+ 'fill_mode': self.fill_mode_name(),
199
+ 'alignment': "%d" % alignment,
200
+ }
201
+ )
202
+
203
+ #
204
+ def configuration_name(self):
205
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
206
+ return self.procedural_name()
207
+
208
+ ###################################################################################################
209
+ #
210
+ # Emits single instances of a CUTLASS device-wide operator
211
+ #
212
+ ###################################################################################################
213
+
214
+ #
215
+ class EmitSymmUniversalInstance:
216
+ ''' Responsible for emitting a CUTLASS template definition'''
217
+
218
+ def __init__(self):
219
+ self.symm_template = """
220
+ // Symm operator ${operation_name}
221
+ using Operation_${operation_name} =
222
+ typename cutlass::gemm::device::Symm<
223
+ ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode},
224
+ ${element_b}, ${layout_b},
225
+ ${element_c}, ${layout_c},
226
+ ${element_accumulator},
227
+ ${opcode_class},
228
+ ${arch},
229
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
230
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
231
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
232
+ ${epilogue_functor}<
233
+ ${element_c},
234
+ ${epilogue_vector_length},
235
+ ${element_accumulator},
236
+ ${element_epilogue}
237
+ >,
238
+ ${swizzling_functor},
239
+ ${stages},
240
+ ${align_a},
241
+ ${align_b},
242
+ ${split_k_serial},
243
+ ${math_operation}
244
+ >;
245
+ """
246
+ self.symm_complex_template = """
247
+ // Symm operator ${operation_name}
248
+ using Operation_${operation_name} =
249
+ typename cutlass::gemm::device::Symm<
250
+ ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode},
251
+ ${element_b}, ${layout_b},
252
+ ${element_c}, ${layout_c},
253
+ ${element_accumulator},
254
+ ${opcode_class},
255
+ ${arch},
256
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
257
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
258
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
259
+ ${epilogue_functor}<
260
+ ${element_c},
261
+ ${epilogue_vector_length},
262
+ ${element_accumulator},
263
+ ${element_epilogue}
264
+ >,
265
+ ${swizzling_functor},
266
+ ${stages},
267
+ ${align_a},
268
+ ${align_b},
269
+ ${split_k_serial},
270
+ ${math_operation},
271
+ ${blas_mode}
272
+ >;
273
+ """
274
+
275
+ def emit(self, operation):
276
+
277
+ threadblock_shape = operation.tile_description.threadblock_shape
278
+
279
+ warp_count = operation.tile_description.warp_count
280
+ warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
281
+
282
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
283
+
284
+ values = {
285
+ 'operation_name': operation.procedural_name(),
286
+ 'element_a': DataTypeTag[operation.A.element],
287
+ 'layout_a': LayoutTag[operation.A.layout],
288
+ 'side_mode': SideModeTag[operation.A.side_mode],
289
+ 'fill_mode': FillModeTag[operation.A.fill_mode],
290
+ 'element_b': DataTypeTag[operation.B.element],
291
+ 'layout_b': LayoutTag[operation.B.layout],
292
+ 'element_c': DataTypeTag[operation.C.element],
293
+ 'layout_c': LayoutTag[operation.C.layout],
294
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
295
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
296
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
297
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
298
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
299
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
300
+ 'warp_shape_m': str(warp_shape[0]),
301
+ 'warp_shape_n': str(warp_shape[1]),
302
+ 'warp_shape_k': str(warp_shape[2]),
303
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
304
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
305
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
306
+ 'epilogue_vector_length': str(epilogue_vector_length),
307
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
308
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
309
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
310
+ 'stages': str(operation.tile_description.stages),
311
+ 'align_a': str(operation.A.alignment),
312
+ 'align_b': str(operation.B.alignment),
313
+ 'split_k_serial': 'false',
314
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
315
+ 'blas_mode': BlasModeTag[operation.blas_mode]
316
+ }
317
+
318
+ symm_template = self.symm_complex_template if operation.is_complex() else self.symm_template
319
+
320
+ return SubstituteTemplate(symm_template, values)
321
+
322
+ ###################################################################################################
323
+
324
+
325
+ ###################################################################################################
326
+ #
327
+ # Emitters functions for all targets
328
+ #
329
+ ###################################################################################################
330
+
331
+ class EmitSymmConfigurationLibrary:
332
+ def __init__(self, operation_path, configuration_name):
333
+ self.configuration_name = configuration_name
334
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
335
+
336
+ self.instance_emitter = {
337
+ SymmKind.Universal: EmitSymmUniversalInstance,
338
+ }
339
+
340
+ self.symm_kind_wrappers = {
341
+ SymmKind.Universal: 'SymmOperation',
342
+ }
343
+
344
+ self.instance_template = {
345
+ SymmKind.Universal: """
346
+ ${compile_guard_start}
347
+ manifest.append(new ${symm_kind}<
348
+ Operation_${operation_name}
349
+ >("${operation_name}"));
350
+ ${compile_guard_end}
351
+ """
352
+ }
353
+
354
+ self.header_template = """
355
+ /*
356
+ Generated by symm_operation.py - Do not edit.
357
+ */
358
+
359
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
360
+ #include "cutlass/cutlass.h"
361
+ #include "cutlass/library/library.h"
362
+ #include "cutlass/library/manifest.h"
363
+
364
+ #include "library_internal.h"
365
+ #include "symm_operation.h"
366
+
367
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
368
+
369
+ """
370
+
371
+ self.initialize_function_template = """
372
+
373
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
374
+
375
+ namespace cutlass {
376
+ namespace library {
377
+
378
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
379
+
380
+ void initialize_${configuration_name}(Manifest &manifest) {
381
+
382
+ """
383
+ self.epilogue_template = """
384
+
385
+ }
386
+
387
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
388
+
389
+ } // namespace library
390
+ } // namespace cutlass
391
+
392
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
393
+
394
+ """
395
+
396
+ def __enter__(self):
397
+ self.configuration_file = open(self.configuration_path, "w")
398
+ self.configuration_file.write(self.header_template)
399
+
400
+ self.instance_definitions = []
401
+ self.instance_wrappers = []
402
+
403
+ self.operations = []
404
+ return self
405
+
406
+ def emit(self, operation):
407
+ emitter = self.instance_emitter[operation.symm_kind]()
408
+
409
+ self.operations.append(operation)
410
+
411
+ self.instance_definitions.append(emitter.emit(operation))
412
+
413
+ self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.symm_kind], {
414
+ 'configuration_name': self.configuration_name,
415
+ 'operation_name': operation.procedural_name(),
416
+ 'symm_kind': self.symm_kind_wrappers[operation.symm_kind],
417
+ 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
418
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
419
+ 'compile_guard_end': "#endif" \
420
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
421
+ }))
422
+
423
+ def __exit__(self, exception_type, exception_value, traceback):
424
+
425
+ # Write instance definitions in top-level namespace
426
+ for instance_definition in self.instance_definitions:
427
+ self.configuration_file.write(instance_definition)
428
+
429
+ # Add wrapper objects within initialize() function
430
+ self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
431
+ 'configuration_name': self.configuration_name
432
+ }))
433
+
434
+ for instance_wrapper in self.instance_wrappers:
435
+ self.configuration_file.write(instance_wrapper)
436
+
437
+ self.configuration_file.write(self.epilogue_template)
438
+ self.configuration_file.close()
439
+
440
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/cutlass_library/trmm_operation.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for emitting Trmm kernels
35
+ """
36
+
37
+ import enum
38
+ import functools
39
+ import operator
40
+ import os.path
41
+ import shutil
42
+
43
+ try:
44
+ import builtins
45
+ if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
+ raise ImportError("Disabling attempt to import cutlass_library")
47
+ from cutlass_library.library import *
48
+ except ImportError:
49
+ from library import *
50
+
51
+
52
+ ###################################################################################################
53
+ #
54
+ # Data structure modeling a TRMM operation
55
+ #
56
+ ###################################################################################################
57
+
58
+ #
59
+ class TrmmOperation:
60
+ #
61
+ def __init__(self, trmm_kind, arch, tile_description, A, B, C, element_epilogue, \
62
+ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8):
63
+
64
+ self.operation_kind = OperationKind.Trmm
65
+ self.arch = arch
66
+ self.tile_description = tile_description
67
+ self.trmm_kind = trmm_kind
68
+ self.A = A
69
+ self.B = B
70
+ self.C = C
71
+ self.element_epilogue = element_epilogue
72
+ self.epilogue_functor = epilogue_functor
73
+ self.swizzling_functor = swizzling_functor
74
+
75
+ #
76
+ def is_complex(self):
77
+ complex_operators = [
78
+ MathOperation.multiply_add_complex,
79
+ MathOperation.multiply_add_complex_gaussian,
80
+ MathOperation.multiply_add_complex_fast_f32
81
+ ]
82
+ return self.tile_description.math_instruction.math_operation in complex_operators
83
+ return False
84
+
85
+ #
86
+ def is_planar_complex(self):
87
+ # return self.trmm_kind in (TrmmKind.PlanarComplex, TrmmKind.PlanarComplexArray)
88
+ return False
89
+
90
+ #
91
+ def is_mixed_input(self):
92
+ return self.A.element != self.B.element
93
+
94
+ #
95
+ def accumulator_type(self):
96
+ accum = self.tile_description.math_instruction.element_accumulator
97
+
98
+ if self.is_complex():
99
+ return get_complex_from_real(accum)
100
+
101
+ return accum
102
+
103
+ #
104
+ def short_math_name(self):
105
+ if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
106
+ return "g%s" % ShortDataTypeNames[self.accumulator_type()]
107
+ return ShortDataTypeNames[self.accumulator_type()]
108
+
109
+
110
+ #
111
+ def core_name(self):
112
+ ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
113
+
114
+ inst_shape = ''
115
+ inst_operation = ''
116
+ intermediate_type = ''
117
+
118
+ math_operations_map = {
119
+ MathOperation.xor_popc: 'xor',
120
+ MathOperation.and_popc: 'and'
121
+ }
122
+
123
+ if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \
124
+ self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp:
125
+
126
+ math_op = self.tile_description.math_instruction.math_operation
127
+ math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ''
128
+
129
+ inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
130
+ inst_shape += math_op_string
131
+
132
+ if self.tile_description.math_instruction.element_a != self.A.element and \
133
+ self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
134
+ intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
135
+
136
+ return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, TrmmKindNames[self.trmm_kind])
137
+
138
+ #
139
+ def extended_name(self):
140
+ ''' Append data types if they differ from compute type. '''
141
+ if self.is_complex():
142
+ extended_name = "${core_name}"
143
+ else:
144
+ if self.C.element != self.tile_description.math_instruction.element_accumulator and \
145
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
146
+ extended_name = "${element_c}_${core_name}_${element_a}"
147
+ elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
148
+ self.A.element != self.tile_description.math_instruction.element_accumulator:
149
+ extended_name = "${core_name}_${element_a}"
150
+ else:
151
+ extended_name = "${core_name}"
152
+
153
+ extended_name = SubstituteTemplate(extended_name, {
154
+ 'element_a': DataTypeNames[self.A.element],
155
+ 'element_c': DataTypeNames[self.C.element],
156
+ 'core_name': self.core_name()
157
+ })
158
+
159
+ return extended_name
160
+
161
+ #
162
+ def layout_name(self):
163
+ if self.is_complex() or self.is_planar_complex():
164
+ return "%s%s" % (
165
+ ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
166
+ ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
167
+ )
168
+ return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
169
+
170
+ #
171
+ def side_mode_name(self):
172
+ return "%s" % (ShortSideModeNames[self.A.side_mode])
173
+
174
+ #
175
+ def fill_mode_name(self):
176
+ return "%s" % (ShortFillModeNames[self.A.fill_mode])
177
+
178
+ #
179
+ def diag_type_name(self):
180
+ return "%s" % (ShortDiagTypeNames[self.A.diag_type])
181
+
182
+ #
183
+ def procedural_name(self):
184
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
185
+ threadblock = self.tile_description.procedural_name()
186
+
187
+ opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
188
+
189
+ alignment = max([self.C.alignment])
190
+
191
+ return SubstituteTemplate(
192
+ "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${side_mode}_${fill_mode}_${diag_type}_align${alignment}",
193
+ {
194
+ 'opcode_class': opcode_class_name,
195
+ 'extended_name': self.extended_name(),
196
+ 'threadblock': threadblock,
197
+ 'layout': self.layout_name(),
198
+ 'side_mode': self.side_mode_name(),
199
+ 'fill_mode': self.fill_mode_name(),
200
+ 'diag_type': self.diag_type_name(),
201
+ 'alignment': "%d" % self.C.alignment,
202
+ }
203
+ )
204
+
205
+ #
206
+ def configuration_name(self):
207
+ ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
208
+ return self.procedural_name()
209
+
210
+ ###################################################################################################
211
+ #
212
+ # Emits single instances of a CUTLASS device-wide operator
213
+ #
214
+ ###################################################################################################
215
+
216
+ #
217
+ class EmitTrmmUniversalInstance:
218
+ ''' Responsible for emitting a CUTLASS template definition'''
219
+
220
+ def __init__(self):
221
+ self.trmm_template = """
222
+ // Trmm operator ${operation_name}
223
+ using Operation_${operation_name} =
224
+ typename cutlass::gemm::device::Trmm<
225
+ ${element_a}, ${layout_a},
226
+ ${side_mode}, ${fill_mode}, ${diag_type},
227
+ ${element_b}, ${layout_b},
228
+ ${element_c}, ${layout_c},
229
+ ${element_accumulator},
230
+ ${opcode_class},
231
+ ${arch},
232
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
233
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
234
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
235
+ ${epilogue_functor}<
236
+ ${element_c},
237
+ ${epilogue_vector_length},
238
+ ${element_accumulator},
239
+ ${element_epilogue},
240
+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
241
+ >,
242
+ ${swizzling_functor},
243
+ ${stages},
244
+ ${align_a},
245
+ ${align_b},
246
+ ${split_k_serial},
247
+ ${math_operation}
248
+ >;
249
+ """
250
+ self.trmm_complex_template = """
251
+ // Trmm operator ${operation_name}
252
+ using Operation_${operation_name} =
253
+ typename cutlass::gemm::device::Trmm<
254
+ ${element_a}, ${layout_a},
255
+ ${side_mode}, ${fill_mode}, ${diag_type},
256
+ ${element_b}, ${layout_b},
257
+ ${element_c}, ${layout_c},
258
+ ${element_accumulator},
259
+ ${opcode_class},
260
+ ${arch},
261
+ cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
262
+ cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
263
+ cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
264
+ ${epilogue_functor}<
265
+ ${element_c},
266
+ ${epilogue_vector_length},
267
+ ${element_accumulator},
268
+ ${element_epilogue},
269
+ cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
270
+ >,
271
+ ${swizzling_functor},
272
+ ${stages},
273
+ ${align_a},
274
+ ${align_b},
275
+ ${split_k_serial},
276
+ ${math_operation},
277
+ ${transform_a}
278
+ >;
279
+ """
280
+
281
+ def emit(self, operation):
282
+
283
+ threadblock_shape = operation.tile_description.threadblock_shape
284
+ warp_count = operation.tile_description.warp_count
285
+
286
+ warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
287
+
288
+ epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
289
+
290
+ values = {
291
+ 'operation_name': operation.procedural_name(),
292
+ 'element_a': DataTypeTag[operation.A.element],
293
+ 'layout_a': LayoutTag[operation.A.layout],
294
+ 'side_mode' : SideModeTag[operation.A.side_mode],
295
+ 'fill_mode': FillModeTag[operation.A.fill_mode],
296
+ 'diag_type' : DiagTypeTag[operation.A.diag_type],
297
+ 'element_b': DataTypeTag[operation.B.element],
298
+ 'layout_b': LayoutTag[operation.B.layout],
299
+ 'element_c': DataTypeTag[operation.C.element],
300
+ 'layout_c': LayoutTag[operation.C.layout],
301
+ 'element_accumulator': DataTypeTag[operation.accumulator_type()],
302
+ 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
303
+ 'arch': "cutlass::arch::Sm%d" % operation.arch,
304
+ 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
305
+ 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
306
+ 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
307
+ 'warp_shape_m': str(warp_shape[0]),
308
+ 'warp_shape_n': str(warp_shape[1]),
309
+ 'warp_shape_k': str(warp_shape[2]),
310
+ 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
311
+ 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
312
+ 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
313
+ 'epilogue_vector_length': str(epilogue_vector_length),
314
+ 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
315
+ 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
316
+ 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
317
+ 'stages': str(operation.tile_description.stages),
318
+ 'align_a': str(1), # TRMM A's alignment is always 1 for no padding to work until we make zfill work with variable bytes
319
+ 'align_b': str(operation.B.alignment),
320
+ 'split_k_serial': 'false',
321
+ 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation],
322
+ 'transform_a': ComplexTransformTag[operation.A.complex_transform]
323
+ }
324
+
325
+ trmm_template = self.trmm_complex_template if operation.is_complex() else self.trmm_template
326
+
327
+ return SubstituteTemplate(trmm_template, values)
328
+
329
+ ###################################################################################################
330
+
331
+
332
+ ###################################################################################################
333
+ #
334
+ # Emitters functions for all targets
335
+ #
336
+ ###################################################################################################
337
+
338
+ class EmitTrmmConfigurationLibrary:
339
+ def __init__(self, operation_path, configuration_name):
340
+ self.configuration_name = configuration_name
341
+ self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name).replace('\\', '/')
342
+
343
+ self.instance_emitter = {
344
+ TrmmKind.Universal: EmitTrmmUniversalInstance,
345
+ }
346
+
347
+ self.trmm_kind_wrappers = {
348
+ TrmmKind.Universal: 'TrmmOperation',
349
+ }
350
+
351
+ self.instance_template = {
352
+ TrmmKind.Universal: """
353
+ ${compile_guard_start}
354
+ manifest.append(new ${trmm_kind}<
355
+ Operation_${operation_name}
356
+ >("${operation_name}"));
357
+ ${compile_guard_end}
358
+ """
359
+ }
360
+
361
+ self.header_template = """
362
+ /*
363
+ Generated by trmm_operation.py - Do not edit.
364
+ */
365
+
366
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
367
+ #include "cutlass/cutlass.h"
368
+ #include "cutlass/library/library.h"
369
+ #include "cutlass/library/manifest.h"
370
+
371
+ #include "library_internal.h"
372
+ #include "trmm_operation.h"
373
+
374
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
375
+
376
+ """
377
+
378
+ self.initialize_function_template = """
379
+
380
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
381
+
382
+ namespace cutlass {
383
+ namespace library {
384
+
385
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
386
+
387
+ void initialize_${configuration_name}(Manifest &manifest) {
388
+
389
+ """
390
+ self.epilogue_template = """
391
+
392
+ }
393
+
394
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
395
+
396
+ } // namespace library
397
+ } // namespace cutlass
398
+
399
+ ///////////////////////////////////////////////////////////////////////////////////////////////////
400
+
401
+ """
402
+
403
+ def __enter__(self):
404
+ self.configuration_file = open(self.configuration_path, "w")
405
+ self.configuration_file.write(self.header_template)
406
+
407
+ self.instance_definitions = []
408
+ self.instance_wrappers = []
409
+
410
+ self.operations = []
411
+ return self
412
+
413
+ def emit(self, operation):
414
+ emitter = self.instance_emitter[operation.trmm_kind]()
415
+
416
+ self.operations.append(operation)
417
+
418
+ self.instance_definitions.append(emitter.emit(operation))
419
+
420
+ self.instance_wrappers.append(SubstituteTemplate(self.instance_template[operation.trmm_kind], {
421
+ 'configuration_name': self.configuration_name,
422
+ 'operation_name': operation.procedural_name(),
423
+ 'trmm_kind': self.trmm_kind_wrappers[operation.trmm_kind],
424
+ 'compile_guard_start': SubstituteTemplate(self.wmma_guard_start, {'sm_number': str(operation.arch)}) \
425
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else "",
426
+ 'compile_guard_end': "#endif" \
427
+ if operation.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp else ""
428
+ }))
429
+
430
+ def __exit__(self, exception_type, exception_value, traceback):
431
+
432
+ # Write instance definitions in top-level namespace
433
+ for instance_definition in self.instance_definitions:
434
+ self.configuration_file.write(instance_definition)
435
+
436
+ # Add wrapper objects within initialize() function
437
+ self.configuration_file.write(SubstituteTemplate(self.initialize_function_template, {
438
+ 'configuration_name': self.configuration_name
439
+ }))
440
+
441
+ for instance_wrapper in self.instance_wrappers:
442
+ self.configuration_file.write(instance_wrapper)
443
+
444
+ self.configuration_file.write(self.epilogue_template)
445
+ self.configuration_file.close()
446
+
447
+ ###################################################################################################
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/docs_src/source/conf.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ # Configuration file for the Sphinx documentation builder.
34
+ #
35
+ # For the full list of built-in configuration values, see the documentation:
36
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html
37
+
38
+ # -- Path setup --------------------------------------------------------------
39
+
40
+ # If extensions (or modules to document with autodoc) are in another directory,
41
+ # add these directories to sys.path here. If the directory is relative to the
42
+ # documentation root, use os.path.abspath to make it absolute, like shown here.
43
+ #
44
+ import os
45
+ import sys
46
+
47
+ sys.path.insert(0, os.path.abspath('..'))
48
+ sys.path.insert(0, os.path.abspath('../..'))
49
+ sys.path.insert(0, os.path.abspath('../../media/docs'))
50
+
51
+ # -- Project information -----------------------------------------------------
52
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
53
+
54
+ project = 'CUTLASS Python interface'
55
+ copyright = '2023, NVIDIA'
56
+ author = 'NVIDIA'
57
+ release = '3.1.0'
58
+
59
+ # -- General configuration ---------------------------------------------------
60
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
61
+
62
+
63
+ # Add any Sphinx extension module names here, as strings. They can be
64
+ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
65
+ # ones.
66
+ extensions = [
67
+ 'myst_parser',
68
+ 'nbsphinx',
69
+ 'nbsphinx_link',
70
+ 'sphinx_copybutton',
71
+ 'sphinx.ext.autodoc',
72
+ 'sphinx.ext.autosectionlabel',
73
+ 'sphinx.ext.autosummary',
74
+ 'sphinx.ext.coverage',
75
+ 'sphinx.ext.extlinks',
76
+ 'sphinx.ext.ifconfig',
77
+ 'sphinx.ext.intersphinx',
78
+ 'sphinx.ext.mathjax',
79
+ 'sphinx.ext.napoleon',
80
+ 'sphinx.ext.viewcode',
81
+ 'sphinx_inline_tabs',
82
+ ]
83
+
84
+ source_suffix = {
85
+ '.rst': 'restructuredtext',
86
+ '.md': 'markdown',
87
+ }
88
+
89
+ autodoc_typehints = 'description'
90
+
91
+ pygments_style = "sphinx"
92
+ pygments_dark_style = "monokai"
93
+
94
+ templates_path = ['_templates']
95
+ exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
96
+
97
+ # Ignore errors when converting notebooks
98
+ nbsphinx_allow_errors = True
99
+
100
+ language = 'en'
101
+ # -- Options for HTML output -------------------------------------------------
102
+ # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
103
+
104
+ html_static_path = ['_static']
105
+
106
+ html_title = "CUTLASS Python"
107
+ html_baseurl = 'docs'
108
+ html_theme = 'furo'
109
+ html_theme_options = {
110
+ "light_logo": "cutlass-logo-small.png",
111
+ "dark_logo": "cutlass-logo-small.png",
112
+ "light_css_variables": {
113
+ "color-brand-primary": "#76B900",
114
+ "color-brand-content": "#76B900",
115
+ },
116
+ "dark_css_variables": {
117
+ "color-brand-primary": "#76B900",
118
+ "color-brand-content": "#76B900",
119
+ },
120
+ "footer_icons": [
121
+ {
122
+ "name": "GitHub",
123
+ "url": "https://github.com/NVIDIA/cutlass",
124
+ "html": """
125
+ <svg stroke="currentColor" fill="currentColor" stroke-width="0" viewBox="0 0 16 16">
126
+ <path fill-rule="evenodd" d="M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z"></path>
127
+ </svg>
128
+ """,
129
+ "class": "",
130
+ },
131
+ ],
132
+ }
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from .int_tuple import *
34
+ from .layout import *
35
+ from .swizzle import *
36
+ from .typing import *
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/int_tuple.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Functions for manipulating IntTuples
35
+ """
36
+
37
+ from functools import reduce
38
+ from itertools import chain
39
+ from typing import Union
40
+ from .typing import Integer
41
+
42
+
43
+ def is_int(x):
44
+ return isinstance(x, Integer)
45
+
46
+
47
+ def is_tuple(x):
48
+ return isinstance(x, tuple)
49
+
50
+
51
+ def flatten(t):
52
+ if is_tuple(t):
53
+ if len(t) == 0:
54
+ return ()
55
+ else:
56
+ return tuple(i for a in t for i in flatten(a))
57
+ else:
58
+ return (t,)
59
+
60
+
61
+ def signum(a):
62
+ return bool(a > 0) - bool(a < 0)
63
+
64
+
65
+ def product(a):
66
+ if is_tuple(a):
67
+ return reduce(lambda val,elem : val*product(elem), a, 1)
68
+ else:
69
+ return a
70
+
71
+
72
+ def inner_product(a, b):
73
+ if is_tuple(a): # tuple tuple
74
+ assert len(a) == len(b)
75
+ return sum(inner_product(x,y) for x,y in zip(a,b))
76
+ else: # "int" "int"
77
+ assert not is_tuple(b)
78
+ return a * b
79
+
80
+
81
+ def tuple_max(a):
82
+ if is_tuple(a):
83
+ return max(tuple_max(x) for x in a)
84
+ else:
85
+ return a
86
+
87
+
88
+ def elem_scale(a, b):
89
+ if is_tuple(a):
90
+ if is_tuple(b): # tuple tuple
91
+ assert len(a) == len(b)
92
+ return tuple(elem_scale(x,y) for x,y in zip(a,b))
93
+ else: # tuple "int"
94
+ assert False # Error
95
+ else:
96
+ if is_tuple(b): # "int" tuple
97
+ return elem_scale(a, product(b))
98
+ else: # "int" "int"
99
+ return a * b
100
+
101
+
102
+ # Inclusive prefix ceil div with output congruent to input a
103
+ def shape_div(a, b):
104
+ if is_tuple(a):
105
+ if is_tuple(b): # tuple tuple
106
+ assert len(a) == len(b)
107
+ return tuple(shape_div(x,y) for x,y in zip(a,b))
108
+ else: # tuple "int"
109
+ #r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))]
110
+ r = []
111
+ for v in a:
112
+ r.append(shape_div(v,b))
113
+ b = shape_div(b,product(v))
114
+ return tuple(r)
115
+ else:
116
+ if is_tuple(b): # "int" tuple
117
+ return shape_div(a, product(b))
118
+ else: # "int" "int"
119
+ assert a % b == 0 or b % a == 0
120
+ return (a + b - 1) // b
121
+
122
+ # Exclusive prefix product with output congruent to input a
123
+ def prefix_product(a, init=1):
124
+ if is_tuple(a):
125
+ if is_tuple(init): # tuple tuple
126
+ assert len(a) == len(init)
127
+ return tuple(prefix_product(x,i) for x,i in zip(a,init))
128
+ else: # tuple "int"
129
+ #r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))]
130
+ r = []
131
+ for v in a:
132
+ r.append(prefix_product(v,init))
133
+ init = init * product(v)
134
+ return tuple(r)
135
+ else:
136
+ if is_tuple(init): # "int" tuple
137
+ assert False # Error
138
+ else: # "int" "int"
139
+ return init
140
+
141
+
142
+ def idx2crd(idx, shape, stride=None):
143
+ if stride is None:
144
+ stride = prefix_product(shape)
145
+
146
+ if is_tuple(idx):
147
+ if is_tuple(shape): # tuple tuple tuple
148
+ assert len(idx) == len(shape) and len(idx) == len(stride)
149
+ return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride))
150
+ else: # tuple "int" "int"
151
+ assert False # Error
152
+ else:
153
+ if is_tuple(shape): # "int" tuple tuple
154
+ assert len(shape) == len(stride)
155
+ return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride))
156
+ else: # "int" "int" "int"
157
+ return (idx // stride) % shape
158
+
159
+
160
+ def crd2idx(crd, shape, stride=None):
161
+ if stride is None:
162
+ stride = prefix_product(shape)
163
+
164
+ if is_tuple(crd):
165
+ if is_tuple(shape): # tuple tuple tuple
166
+ assert len(crd) == len(shape) and len(crd) == len(stride)
167
+ return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride))
168
+ else: # tuple "int" "int"
169
+ assert False, f"crd={crd}, shape={shape}" # Error
170
+ else:
171
+ if crd is None:
172
+ crd = 0
173
+
174
+ if is_tuple(shape): # "int" tuple tuple
175
+ assert len(shape) == len(stride)
176
+ result = 0
177
+ for i in range(len(shape)-1):
178
+ result += crd2idx(crd % product(shape[i]), shape[i], stride[i])
179
+ crd = crd // product(shape[i])
180
+ return result + crd2idx(crd, shape[-1], stride[-1])
181
+ else: # "int" "int" "int"
182
+ return crd * stride
183
+
184
+
185
+ # Transform crd into the dst_shape's iteration space
186
+ def crd2crd(crd, dst_shape, src_shape=None):
187
+ if is_tuple(crd):
188
+ if is_tuple(dst_shape): # tuple tuple
189
+ assert len(crd) == len(dst_shape)
190
+ return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape))
191
+ else: # tuple "int"
192
+ # Ambiguous unless we have src_shape
193
+ assert src_shape is not None
194
+ return crd2idx(crd, src_shape)
195
+ else:
196
+ if is_tuple(dst_shape): # "int" tuple
197
+ return idx2crd(crd, dst_shape)
198
+ else: # "int" "int"
199
+ assert crd < dst_shape
200
+ return crd
201
+
202
+
203
+ # Filter trg according to crd: keep only elements of trg that are paired with None
204
+ def slice_(crd: Union[None, tuple, int],
205
+ trg: Union[tuple, int]):
206
+ if is_tuple(crd):
207
+ if is_tuple(trg): # tuple tuple
208
+ assert len(crd) == len(trg)
209
+ # match C++ behavior of `filter_tuple` using `tuple_cat(...)`
210
+ return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)])))
211
+ else:
212
+ assert False # tuple "int" : Error
213
+ elif crd is None:
214
+ # match C++ behavior `return cute::tuple<B>{b};`
215
+ return (trg,)
216
+ else:
217
+ return ()
218
+
219
+
220
+ # Determine if None appears at any of an int_tuples' terminals
221
+ def has_none(a: Union[None, tuple, int]):
222
+ if is_tuple(a):
223
+ return any(has_none(v) for v in a)
224
+ else:
225
+ return a is None
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/layout.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Definition of CuTe Layouts and functions to manipulate them
35
+ """
36
+
37
+ from itertools import chain
38
+ from typing import Union
39
+
40
+ from .int_tuple import *
41
+
42
+
43
+ class LayoutBase:
44
+ pass
45
+
46
+
47
+ def is_layout(x):
48
+ return isinstance(x, LayoutBase)
49
+
50
+
51
+ class Layout(LayoutBase):
52
+ def __init__(self, _shape, _stride=None):
53
+ self.shape = _shape
54
+ if _stride is None:
55
+ self.stride = prefix_product(self.shape)
56
+ else:
57
+ self.stride = _stride
58
+
59
+ # operator ==
60
+ def __eq__(self, other):
61
+ return self.shape == other.shape and self.stride == other.stride
62
+
63
+ # operator len(L) (len [rank] like tuples)
64
+ def __len__(self):
65
+ if is_tuple(self.shape):
66
+ return len(self.shape)
67
+ else:
68
+ return 1
69
+
70
+ # operator () (map coord to idx)
71
+ def __call__(self, *args):
72
+ """
73
+ Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
74
+ OR
75
+ Slice the layout and return the sublayout (Coord has an Underscore slice op)
76
+
77
+ Follow the same behavior of `Layout::operator(Coord const&)` in cute C++
78
+ """
79
+ if has_none(args):
80
+ if len(args) == 1:
81
+ return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
82
+ else:
83
+ return Layout(slice_(args, self.shape), slice_(args, self.stride))
84
+ else:
85
+ if len(args) == 1:
86
+ return crd2idx(args[0], self.shape, self.stride)
87
+ else:
88
+ return crd2idx(args, self.shape, self.stride)
89
+
90
+ # operator [] (get-i like tuples)
91
+ def __getitem__(self, i):
92
+ if is_tuple(self.shape):
93
+ return Layout(self.shape[i], self.stride[i])
94
+ else:
95
+ assert i == 0
96
+ return Layout(self.shape, self.stride)
97
+
98
+ # size(layout) Size of the domain
99
+ def size(self):
100
+ return product(self.shape)
101
+
102
+ # cosize(layout) Size of the codomain
103
+ def cosize(self):
104
+ return self(self.size() - 1) + 1
105
+
106
+ # print and str
107
+ def __str__(self):
108
+ return f"{self.shape}:{self.stride}"
109
+
110
+ # error msgs and representation
111
+ def __repr__(self):
112
+ return f"Layout({self.shape},{self.stride})"
113
+
114
+
115
+ # Make Layout from a list of layouts (each layout it's own mode in the result)
116
+ def make_layout(*layouts):
117
+ if len(layouts) == 1 and not is_layout(layouts[0]):
118
+ layouts = layouts[0]
119
+
120
+ shape, stride = zip(*((a.shape,a.stride) for a in layouts))
121
+ return Layout(shape, stride)
122
+
123
+
124
+ # Size of the domain
125
+ def size(layout):
126
+ if is_layout(layout):
127
+ return layout.size()
128
+ return product(layout)
129
+
130
+
131
+ # Size of the codomain
132
+ def cosize(layout):
133
+ return layout.cosize()
134
+
135
+
136
+ # Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
137
+ def coalesce(layout, profile=None):
138
+ if is_tuple(profile):
139
+ assert len(layout) >= len(profile)
140
+ return make_layout(chain((coalesce(layout[i], profile[i]) for i in range( 0,len(profile))),
141
+ (layout[i] for i in range(len(profile),len(layout)))))
142
+
143
+ result_shape = [1]
144
+ result_stride = [0]
145
+ for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)):
146
+ # skip their shape-1s
147
+ if shape == 1:
148
+ continue
149
+ # replace our shape-1 with anything
150
+ elif result_shape[-1] == 1:
151
+ result_shape[-1] = shape
152
+ result_stride[-1] = stride
153
+ # merge modes if the shape*stride match
154
+ elif result_shape[-1] * result_stride[-1] == stride:
155
+ result_shape[-1] = result_shape[-1] * shape
156
+ # append a new mode
157
+ else:
158
+ result_shape.append(shape)
159
+ result_stride.append(stride)
160
+
161
+ if len(result_shape) == 1:
162
+ return Layout(result_shape[0], result_stride[0])
163
+ else:
164
+ return Layout(tuple(result_shape), tuple(result_stride))
165
+
166
+
167
+ # Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
168
+ def filter(layout, profile=None):
169
+ if is_tuple(profile):
170
+ assert len(layout) >= len(profile)
171
+ return make_layout(chain((filter(layout[i], profile[i]) for i in range( 0,len(profile))),
172
+ (layout[i] for i in range(len(profile),len(layout)))))
173
+
174
+ result_shape = []
175
+ result_stride = []
176
+ for (shape,stride) in zip(flatten(layout.shape),flatten(layout.stride)):
177
+ # skip their shape-1s and stride-0s
178
+ if not (shape == 1 or stride == 0):
179
+ result_shape.append(shape)
180
+ result_stride.append(stride)
181
+
182
+ if len(result_shape) == 0:
183
+ return Layout(1,0)
184
+ else:
185
+ return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
186
+
187
+
188
+ # Layout composition
189
+ # Use tuples-of-layouts to perform this operation by-mode and None as no-op
190
+ def composition(layoutA, layoutB):
191
+ if layoutB is None:
192
+ return layoutA
193
+ elif is_int(layoutB):
194
+ return composition(layoutA, Layout(layoutB))
195
+ elif is_tuple(layoutB):
196
+ assert len(layoutA) >= len(layoutB)
197
+ return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
198
+ (layoutA[i] for i in range(len(layoutB),len(layoutA)))))
199
+ elif is_tuple(layoutB.shape):
200
+ return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)
201
+
202
+ if layoutB.stride == 0:
203
+ return Layout(layoutB.shape, 0)
204
+ else:
205
+ result_shape = []
206
+ result_stride = []
207
+ rest_shape = layoutB.shape
208
+ rest_stride = layoutB.stride
209
+ flat_A = coalesce(layoutA)
210
+ for (curr_shape, curr_stride) in zip(flatten(flat_A.shape)[:-1], flatten(flat_A.stride)[:-1]):
211
+ assert curr_shape % rest_stride == 0 or rest_stride % curr_shape == 0
212
+ new_shape = min(max(1, curr_shape // rest_stride), rest_shape)
213
+
214
+ if new_shape != 1:
215
+ result_shape.append(new_shape)
216
+ result_stride.append(rest_stride * curr_stride)
217
+
218
+ rest_shape = rest_shape // new_shape
219
+ rest_stride = -(-rest_stride // curr_shape) # Python exclusive impl: "//" is always floor div so == ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride)
220
+
221
+ if rest_shape != 1 or len(result_shape) == 0:
222
+ result_shape.append(rest_shape)
223
+ result_stride.append(rest_stride * flatten(flat_A.stride)[-1])
224
+
225
+ if len(result_shape) == 1:
226
+ return Layout(result_shape[0], result_stride[0])
227
+ else:
228
+ return Layout(tuple(result_shape), tuple(result_stride))
229
+
230
+
231
+ # Layout complement
232
+ def complement(layout, max_idx=1):
233
+ if is_int(layout):
234
+ return complement(Layout(layout))
235
+
236
+ result_shape = []
237
+ result_stride = []
238
+ current_idx = 1
239
+
240
+ sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
241
+ for (stride, shape) in sorted_DS:
242
+ if stride == 0 or shape == 1:
243
+ continue
244
+
245
+ in_bound = current_idx <= shape * stride
246
+ # To support symbolic value which can't be evaluated now
247
+ assert (type(in_bound) is not bool) or in_bound
248
+
249
+ result_shape.append(stride // current_idx)
250
+ result_stride.append(current_idx)
251
+ current_idx = shape * stride
252
+
253
+ result_shape.append((max_idx + current_idx - 1) // current_idx) # ceil_div
254
+ result_stride.append(current_idx)
255
+
256
+ return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
257
+
258
+
259
+ # Layout right inverse
260
+ def right_inverse(layout):
261
+ if layout is None:
262
+ return None
263
+ elif is_int(layout):
264
+ return Layout(layout)
265
+
266
+ result_shape = []
267
+ result_stride = []
268
+ current_idx = 1
269
+
270
+ flat_shape = flatten(layout.shape)
271
+ flat_stride = flatten(layout.stride)
272
+ sorted_DSA = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))
273
+ for (stride,shape,rstride) in sorted_DSA:
274
+ if shape == 1:
275
+ continue
276
+ if current_idx != stride:
277
+ break
278
+
279
+ result_shape.append(shape)
280
+ result_stride.append(rstride)
281
+ current_idx = shape * stride
282
+
283
+ return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
284
+
285
+
286
+ # Layout left inverse
287
+ def left_inverse(layout):
288
+ if layout is None:
289
+ return None
290
+ elif is_int(layout):
291
+ return Layout(layout)
292
+ return right_inverse(make_layout(layout, complement(layout)))
293
+
294
+
295
+ # Split a layout by the composition of B and the "rest"
296
+ # Use tuples-of-layouts to perform this operation by-mode and None as no-op
297
+ def logical_divide(layoutA, layoutB):
298
+ if layoutB is None:
299
+ return layoutA
300
+ elif is_int(layoutB):
301
+ return logical_divide(layoutA, Layout(layoutB))
302
+ elif is_tuple(layoutB):
303
+ assert len(layoutA) >= len(layoutB)
304
+ return make_layout(chain((logical_divide(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
305
+ (layoutA[i] for i in range(len(layoutB),len(layoutA)))))
306
+
307
+ return composition(layoutA, make_layout(layoutB, complement(layoutB, size(layoutA))))
308
+
309
+
310
+ # Reproduce a layoutA over a layoutB
311
+ # Use tuples-of-layouts to perform this operation by-mode and None as no-op
312
+ def logical_product(layoutA, layoutB):
313
+ if layoutB is None:
314
+ return layoutA
315
+ elif is_int(layoutB):
316
+ return logical_divide(layoutA, Layout(layoutB))
317
+ elif is_tuple(layoutB):
318
+ assert len(layoutA) >= len(layoutB)
319
+ return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
320
+ (layoutA[i] for i in range(len(layoutB),len(layoutA)))))
321
+
322
+ return make_layout(layoutA, composition(complement(layoutA, size(layoutA)*cosize(layoutB)), layoutB));
323
+
324
+
325
+ # Gather the modes from a hierarchical logical_divide or logical_product
326
+ def hier_unzip(splitter, layoutA, layoutB):
327
+ if layoutB is None:
328
+ return make_layout(Layout(1,0), layoutA)
329
+ elif is_tuple(layoutB):
330
+ assert len(layoutA) >= len(layoutB)
331
+ # A layout with shape ((A,a),(B,b),(C,c))
332
+ split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(0,len(layoutB)))
333
+ # Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
334
+ return make_layout(make_layout( split[i][0] for i in range( 0,len(layoutB))),
335
+ make_layout(chain((split[i][1] for i in range( 0,len(layoutB))),
336
+ (layoutA[i] for i in range(len(layoutB),len(layoutA))))))
337
+
338
+ # splitter must return a rank-2 layout
339
+ return splitter(layoutA, layoutB)
340
+
341
+
342
+ # Apply logical divide hierarchically and gather the split modes into two modes
343
+ def zipped_divide(layoutA, layoutB):
344
+ return hier_unzip(logical_divide, layoutA, layoutB)
345
+
346
+
347
+ # Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
348
+ def tiled_divide(layoutA, layoutB):
349
+ result = zipped_divide(layoutA, layoutB)
350
+ return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
351
+
352
+
353
+ # Apply logical product hierarchically and gather the split modes into two modes
354
+ def zipped_product(layoutA, layoutB):
355
+ return hier_unzip(logical_product, layoutA, layoutB)
356
+
357
+
358
+ # Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
359
+ def tiled_product(layoutA, layoutB):
360
+ result = zipped_product(layoutA, layoutB)
361
+ return make_layout([result[0]] + [result[1][i] for i in range(len(result[1]))])
362
+
363
+
364
+ def slice_and_offset(crd: tuple,
365
+ layout: Layout):
366
+ return (Layout(slice_(crd, layout.shape), slice_(crd, layout.stride)),
367
+ crd2idx(crd, layout.shape, layout.stride))
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/swizzle.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Methods for layout swizzling
35
+ """
36
+
37
+ from .layout import *
38
+
39
+
40
+ def shiftr(a, s):
41
+ return a >> s if s > 0 else shiftl(a, -s)
42
+
43
+
44
+ def shiftl(a, s):
45
+ return a << s if s > 0 else shiftr(a, -s)
46
+
47
+
48
+ ## A generic Swizzle functor
49
+ # 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
50
+ # ^--^ Base is the number of least-sig bits to keep constant
51
+ # ^-^ ^-^ Bits is the number of bits in the mask
52
+ # ^---------^ Shift is the distance to shift the YYY mask
53
+ # (pos shifts YYY to the right, neg shifts YYY to the left)
54
+ #
55
+ # e.g. Given
56
+ # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
57
+ # the result is
58
+ # 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
59
+ #
60
+ class Swizzle:
61
+ def __init__(self, bits, base, shift):
62
+ assert bits >= 0
63
+ assert base >= 0
64
+ assert abs(shift) >= bits
65
+ self.bits = bits
66
+ self.base = base
67
+ self.shift = shift
68
+ bit_msk = (1 << bits) - 1
69
+ self.yyy_msk = bit_msk << (base + max(0,shift))
70
+ self.zzz_msk = bit_msk << (base - min(0,shift))
71
+
72
+ # operator () (transform integer)
73
+ def __call__(self, offset):
74
+ return offset ^ shiftr(offset & self.yyy_msk, self.shift)
75
+
76
+ # Size of the domain
77
+ def size(self):
78
+ return 1 << (self.bits + self.base + abs(self.shift))
79
+
80
+ # Size of the codomain
81
+ def cosize(self):
82
+ return self.size()
83
+
84
+ # print and str
85
+ def __str__(self):
86
+ return f"SW_{self.bits}_{self.base}_{self.shift}"
87
+
88
+ # error msgs and representation
89
+ def __repr__(self):
90
+ return f"Swizzle({self.bits},{self.base},{self.shift})"
91
+
92
+
93
+ class ComposedLayout(LayoutBase):
94
+ def __init__(self, layoutB, offset, layoutA):
95
+ self.layoutB = layoutB
96
+ self.offset = offset
97
+ self.layoutA = layoutA
98
+
99
+ # operator ==
100
+ def __eq__(self, other):
101
+ return self.layoutB == other.layoutB and self.offset == other.offset and self.layoutA == other.layoutA
102
+
103
+ # operator len(L) (len [rank] like tuples)
104
+ def __len__(self):
105
+ return len(self.layoutA)
106
+
107
+ # operator () (map coord to idx)
108
+ def __call__(self, *args):
109
+ return self.layoutB(self.offset + self.layoutA(*args))
110
+
111
+ # operator [] (get-i like tuples)
112
+ def __getitem__(self, i):
113
+ return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])
114
+
115
+ # size(layout) Size of the domain
116
+ def size(self):
117
+ return size(self.layoutA)
118
+
119
+ # cosize(layout) Size of the codomain
120
+ def cosize(self):
121
+ return cosize(self.layoutB)
122
+
123
+ # print and str
124
+ def __str__(self):
125
+ return f"{self.layoutB} o {self.offset} o {self.layoutA}"
126
+
127
+ # error msgs and representation
128
+ def __repr__(self):
129
+ return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/pycute/typing.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from abc import ABC
34
+
35
+
36
+ class Integer(ABC):
37
+ @classmethod
38
+ def __subclasshook__(cls, c):
39
+ if c in [bool, float]:
40
+ return False
41
+
42
+ return issubclass(c, int)
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_cutlass.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+
34
+ import copy
35
+ import os
36
+ import setuptools
37
+ from setuptools import setup
38
+ from setuptools.command.build_ext import build_ext
39
+
40
+ import setup_pycute
41
+ import setup_library
42
+
43
+
44
+ # Install cutlass_library package
45
+ setup_library.perform_setup()
46
+
47
+
48
+ # Install the PyCuTe package
49
+ setup_pycute.perform_setup()
50
+
51
+
52
+ setup(
53
+ name='cutlass_cppgen',
54
+ version='4.2.0',
55
+ description='CUTLASS Pythonic Interface',
56
+ package_dir={'': '.'},
57
+ packages=[
58
+ 'cutlass_cppgen',
59
+ 'cutlass_cppgen.emit',
60
+ 'cutlass_cppgen.op',
61
+ 'cutlass_cppgen.utils',
62
+ 'cutlass_cppgen.backend',
63
+ 'cutlass_cppgen.backend.utils'
64
+ ],
65
+ setup_requires=['pybind11'],
66
+ install_requires=[
67
+ 'bfloat16',
68
+ 'cuda-python>=11.8.0',
69
+ 'pybind11',
70
+ 'scikit-build',
71
+ 'treelib',
72
+ 'pydot'
73
+ ]
74
+ )
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_library.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from setuptools import setup
34
+
35
+
36
+ def perform_setup():
37
+ setup(
38
+ name='cutlass_library',
39
+ version='4.2.1',
40
+ description='CUTLASS library generation scripts',
41
+ packages=['cutlass_library']
42
+ )
43
+
44
+
45
+ if __name__ == '__main__':
46
+ perform_setup()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/python/setup_pycute.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from setuptools import setup
34
+
35
+
36
+ def perform_setup():
37
+ setup(
38
+ name='pycute',
39
+ version='4.2.1',
40
+ description='Python implementation of CuTe',
41
+ packages=['pycute'],
42
+ )
43
+
44
+
45
+ if __name__ == '__main__':
46
+ perform_setup()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_problem_sizes.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for defining Conv2D problem sizes for testing.
35
+
36
+ This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h
37
+ """
38
+
39
+ from cutlass_library import ConvMode
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.shape import Conv2DProblemSize
43
+
44
+
45
+ class TestbedConv2dProblemSizes:
46
+ def __init__(self, minimum_channel_size: int):
47
+ conv2d_default_sizes = self.initialize_conv2d_default_sizes(minimum_channel_size)
48
+ conv2d_rigorous_sizes = self.initialize_conv2d_rigorous_sizes(minimum_channel_size)
49
+ conv2d_resnet50_sizes = self.initialize_conv2d_resnet50_sizes(1)
50
+ conv2d_resnet50_sizes_perf = self.initialize_conv2d_resnet50_sizes(34)
51
+ grouped_sizes = self.initialize_conv2d_grouped_sizes()
52
+
53
+ # Filter all problems
54
+ self.all = []
55
+ for size_list in [conv2d_default_sizes, conv2d_rigorous_sizes, conv2d_resnet50_sizes, conv2d_resnet50_sizes_perf, grouped_sizes]:
56
+ for size in size_list:
57
+ if (size.C // size.groups) % minimum_channel_size == 0:
58
+ self.all.append(size)
59
+
60
+
61
+ def initialize_conv2d_default_sizes(self, minimum_channel_size):
62
+ # Small input size x stride (1,1)
63
+ # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
64
+
65
+ conv2d_default_sizes = []
66
+ conv2d_default_sizes.append(Conv2DProblemSize(
67
+ 1, 1, 1, minimum_channel_size,
68
+ 8, 1, 1, minimum_channel_size,
69
+ 1, 1,
70
+ 1, 1,
71
+ 1, 1,
72
+ ))
73
+
74
+ conv2d_default_sizes.append(Conv2DProblemSize(
75
+ 1, 1, 8, minimum_channel_size,
76
+ 8, 1, 3, minimum_channel_size,
77
+ 1, 1,
78
+ 1, 1,
79
+ 1, 1,
80
+ ))
81
+
82
+ conv2d_default_sizes.append(Conv2DProblemSize(
83
+ 1, 7, 8, minimum_channel_size,
84
+ 8, 3, 3, minimum_channel_size,
85
+ 1, 1,
86
+ 1, 1,
87
+ 1, 1,
88
+ ))
89
+
90
+ conv2d_default_sizes.append(Conv2DProblemSize(
91
+ 1, 7, 9, minimum_channel_size,
92
+ 8, 4, 4, minimum_channel_size,
93
+ 1, 1,
94
+ 1, 1,
95
+ 1, 1,
96
+ ))
97
+
98
+ conv2d_default_sizes.append(Conv2DProblemSize(
99
+ 2, 7, 9, minimum_channel_size,
100
+ 8, 5, 5, minimum_channel_size,
101
+ 1, 1,
102
+ 1, 1,
103
+ 1, 1,
104
+ ))
105
+
106
+ conv2d_default_sizes.append(Conv2DProblemSize(
107
+ 3, 7, 9, minimum_channel_size,
108
+ 8, 6, 5, minimum_channel_size,
109
+ 1, 1,
110
+ 1, 1,
111
+ 1, 1,
112
+ ))
113
+
114
+ conv2d_default_sizes.append(Conv2DProblemSize(
115
+ 3, 7, 9, minimum_channel_size,
116
+ 8, 6, 6, minimum_channel_size,
117
+ 1, 1,
118
+ 1, 1,
119
+ 1, 1,
120
+ ))
121
+
122
+ conv2d_default_sizes.append(Conv2DProblemSize(
123
+ 3, 7, 9, minimum_channel_size,
124
+ 8, 7, 7, minimum_channel_size,
125
+ 1, 1,
126
+ 1, 1,
127
+ 1, 1,
128
+ ))
129
+
130
+ ##############################################
131
+ # Small input size x stride (2,2)
132
+ # C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
133
+ ##############################################
134
+ conv2d_default_sizes.append(Conv2DProblemSize(
135
+ 1, 11, 7, minimum_channel_size,
136
+ 8, 1, 1, minimum_channel_size,
137
+ 0, 0,
138
+ 2, 2,
139
+ 1, 1,
140
+ ))
141
+
142
+ conv2d_default_sizes.append(Conv2DProblemSize(
143
+ 1, 11, 7, minimum_channel_size,
144
+ 8, 3, 3, minimum_channel_size,
145
+ 1, 1,
146
+ 2, 2,
147
+ 1, 1,
148
+ ))
149
+
150
+ conv2d_default_sizes.append(Conv2DProblemSize(
151
+ 1, 13, 11, minimum_channel_size,
152
+ 8, 1, 1, minimum_channel_size,
153
+ 1, 1,
154
+ 2, 2,
155
+ 1, 1,
156
+ ))
157
+
158
+ conv2d_default_sizes.append(Conv2DProblemSize(
159
+ 1, 17, 19, minimum_channel_size,
160
+ 16, 2, 2, minimum_channel_size,
161
+ 1, 1,
162
+ 2, 2,
163
+ 1, 1,
164
+ ))
165
+
166
+ conv2d_default_sizes.append(Conv2DProblemSize(
167
+ 1, 23, 5, minimum_channel_size,
168
+ 16, 3, 3, minimum_channel_size,
169
+ 1, 1,
170
+ 2, 2,
171
+ 1, 1,
172
+ ))
173
+
174
+ conv2d_default_sizes.append(Conv2DProblemSize(
175
+ 1, 13, 17, 8,
176
+ 24, 3, 3, 8,
177
+ 0, 0,
178
+ 2, 2,
179
+ 1, 1,
180
+ ))
181
+
182
+ conv2d_default_sizes.append(Conv2DProblemSize(
183
+ 1, 23, 21, 8,
184
+ 24, 3, 3, 8,
185
+ 1, 1,
186
+ 3, 3,
187
+ 1, 1,
188
+ ))
189
+
190
+ conv2d_default_sizes.append(Conv2DProblemSize(
191
+ 1, 20, 24, 8,
192
+ 40, 3, 3, 8,
193
+ 3, 3,
194
+ 3, 3,
195
+ 1, 1,
196
+ ))
197
+
198
+ ##########################################
199
+ # Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1)
200
+ ##########################################
201
+ conv2d_default_sizes.append(Conv2DProblemSize(
202
+ 1, 15, 19, 160,
203
+ 224, 1, 1, 160,
204
+ 0, 0,
205
+ 1, 1,
206
+ 1, 1,
207
+ ))
208
+
209
+ conv2d_default_sizes.append(Conv2DProblemSize(
210
+ 1, 19, 37, 160,
211
+ 224, 3, 3, 160,
212
+ 1, 1,
213
+ 2, 2,
214
+ 1, 1,
215
+ ))
216
+
217
+ conv2d_default_sizes.append(Conv2DProblemSize(
218
+ 1, 16, 16, 160,
219
+ 224, 2, 3, 160,
220
+ 1, 1,
221
+ 1, 1,
222
+ 1, 1,
223
+ ))
224
+
225
+ conv2d_default_sizes.append(Conv2DProblemSize(
226
+ 1, 23, 21, 128,
227
+ 224, 3, 3, 128,
228
+ 1, 1,
229
+ 1, 1,
230
+ 1, 1,
231
+ ))
232
+
233
+ conv2d_default_sizes.append(Conv2DProblemSize(
234
+ 1, 29, 37, 160,
235
+ 224, 5, 5, 160,
236
+ 2, 2,
237
+ 1, 1,
238
+ 1, 1,
239
+ ))
240
+
241
+ ##########################################
242
+ # C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
243
+ ##########################################
244
+ conv2d_default_sizes.append(Conv2DProblemSize(
245
+ 1, 15, 19, 32 + minimum_channel_size,
246
+ 96, 3, 3, 32 + minimum_channel_size,
247
+ 1, 1,
248
+ 1, 1,
249
+ 1, 1,
250
+ ))
251
+
252
+ conv2d_default_sizes.append(Conv2DProblemSize(
253
+ 1, 16, 24, 64 + minimum_channel_size,
254
+ 96, 3, 3, 64 + minimum_channel_size,
255
+ 1, 1,
256
+ 1, 1,
257
+ 1, 1,
258
+ ))
259
+
260
+ ##########################################
261
+ # Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2)
262
+ ##########################################
263
+ conv2d_default_sizes.append(Conv2DProblemSize(
264
+ 1, 13, 16, 288,
265
+ 160, 5, 5, 288,
266
+ 2, 2,
267
+ 2, 2,
268
+ 1, 1,
269
+ ))
270
+
271
+ conv2d_default_sizes.append(Conv2DProblemSize(
272
+ 1, 55, 51, 256,
273
+ 512, 1, 1, 256,
274
+ 0, 0,
275
+ 2, 2,
276
+ 1, 1,
277
+ ))
278
+
279
+ conv2d_default_sizes.append(Conv2DProblemSize(
280
+ 1, 71, 80, 32,
281
+ 64, 5, 5, 32,
282
+ 2, 2,
283
+ 2, 2,
284
+ 1, 1,
285
+ ))
286
+
287
+ conv2d_default_sizes.append(Conv2DProblemSize(
288
+ 1, 224, 224, 8,
289
+ 64, 7, 7, 8,
290
+ 3, 3,
291
+ 2, 2,
292
+ 1, 1,
293
+ ))
294
+
295
+ ##########################################
296
+ # Medium input size stride (3, 3), filter (3, 3), non-default padding
297
+ ##########################################
298
+ conv2d_default_sizes.append(Conv2DProblemSize(
299
+ 1, 27, 23, 256,
300
+ 512, 3, 3, 256,
301
+ 0, 0,
302
+ 3, 3,
303
+ 1, 1,
304
+ ))
305
+
306
+ ##########################################
307
+ # Medium input size padding > stride, asymmetric filter, padding and striding
308
+ ##########################################
309
+ conv2d_default_sizes.append(Conv2DProblemSize(
310
+ 1, 27, 31, 256,
311
+ 512, 3, 3, 256,
312
+ 5, 7,
313
+ 3, 4,
314
+ 1, 1,
315
+ ))
316
+
317
+ conv2d_default_sizes.append(Conv2DProblemSize(
318
+ 1, 27, 35, 256,
319
+ 512, 7, 5, 256,
320
+ 11, 7,
321
+ 3, 5,
322
+ 1, 1,
323
+ ))
324
+
325
+ ##########################################
326
+ # Medium input size *mixed* stride (1, 2) and (2, 1),
327
+ # filter (3, 3), default padding
328
+ ##########################################
329
+ conv2d_default_sizes.append(Conv2DProblemSize(
330
+ 1, 27, 27, 256,
331
+ 512, 3, 3, 256,
332
+ 1, 1,
333
+ 1, 2,
334
+ 1, 1,
335
+ ))
336
+
337
+ conv2d_default_sizes.append(Conv2DProblemSize(
338
+ 1, 27, 27, 256,
339
+ 512, 3, 3, 256,
340
+ 1, 1,
341
+ 2, 1,
342
+ 1, 1,
343
+ ))
344
+
345
+ ######################################/
346
+ # Additional input size
347
+ ######################################/
348
+ conv2d_default_sizes.append(Conv2DProblemSize(
349
+ 3, 28, 28, 256,
350
+ 256, 2, 2, 256,
351
+ 0, 0,
352
+ 2, 2,
353
+ 1, 1,
354
+ ))
355
+
356
+ conv2d_default_sizes.append(Conv2DProblemSize(
357
+ 1, 32, 32, 16,
358
+ 32, 3, 3, 16,
359
+ 1, 1,
360
+ 6, 2,
361
+ 1, 1,
362
+ ))
363
+
364
+ conv2d_default_sizes.append(Conv2DProblemSize(
365
+ 32, 24, 32, 32,
366
+ 32, 1, 2, 32,
367
+ 0, 0,
368
+ 1, 1,
369
+ 1, 1,
370
+ ))
371
+
372
+ conv2d_default_sizes.append(Conv2DProblemSize(
373
+ 4, 2, 3, 256,
374
+ 328, 3, 5, 256,
375
+ 1, 1,
376
+ 1, 1,
377
+ 1, 1,
378
+ ))
379
+ return conv2d_default_sizes
380
+
381
+ # Add a few large and rigorous convolution problem sizes
382
+ def initialize_conv2d_rigorous_sizes(self, minimum_channel_size):
383
+ sizes = []
384
+ if False:
385
+ sizes.append(Conv2DProblemSize.from_sizes(
386
+ (1, 124, 224, 2 * minimum_channel_size),
387
+ (24, 7, 7, 2 * minimum_channel_size),
388
+ ))
389
+
390
+ sizes.append(Conv2DProblemSize.from_sizes(
391
+ (1, 233, 35, minimum_channel_size),
392
+ (24, 7, 5, minimum_channel_size),
393
+ ))
394
+ return sizes
395
+
396
+ # Add resent50 layers to unit testing sizes
397
+ def initialize_conv2d_resnet50_sizes(self, batch_size):
398
+ conv2d_problem_vector = []
399
+ conv2d_problem_vector.append(Conv2DProblemSize(
400
+ batch_size, 56, 56, 64,
401
+ 256, 1, 1, 64,
402
+ 0, 0,
403
+ 1, 1,
404
+ 1, 1,
405
+ ))
406
+
407
+ conv2d_problem_vector.append(Conv2DProblemSize(
408
+ batch_size, 56, 56, 64,
409
+ 64, 1, 1, 64,
410
+ 0, 0,
411
+ 1, 1,
412
+ 1, 1,
413
+ ))
414
+
415
+ conv2d_problem_vector.append(Conv2DProblemSize(
416
+ batch_size, 56, 56, 64,
417
+ 64, 3, 3, 64,
418
+ 1, 1,
419
+ 1, 1,
420
+ 1, 1,
421
+ ))
422
+
423
+ conv2d_problem_vector.append(Conv2DProblemSize(
424
+ batch_size, 56, 56, 256,
425
+ 64, 1, 1, 256,
426
+ 0, 0,
427
+ 1, 1,
428
+ 1, 1,
429
+ ))
430
+
431
+ conv2d_problem_vector.append(Conv2DProblemSize(
432
+ batch_size, 56, 56, 256,
433
+ 512, 1, 1, 256,
434
+ 0, 0,
435
+ 2, 2,
436
+ 1, 1,
437
+ ))
438
+
439
+ conv2d_problem_vector.append(Conv2DProblemSize(
440
+ batch_size, 56, 56, 256,
441
+ 128, 1, 1, 256,
442
+ 0, 0,
443
+ 2, 2,
444
+ 1, 1,
445
+ ))
446
+
447
+ conv2d_problem_vector.append(Conv2DProblemSize(
448
+ batch_size, 28, 28, 128,
449
+ 128, 3, 3, 128,
450
+ 1, 1,
451
+ 1, 1,
452
+ 1, 1,
453
+ ))
454
+
455
+ conv2d_problem_vector.append(Conv2DProblemSize(
456
+ batch_size, 28, 28, 128,
457
+ 512, 1, 1, 128,
458
+ 0, 0,
459
+ 1, 1,
460
+ 1, 1,
461
+ ))
462
+
463
+ conv2d_problem_vector.append(Conv2DProblemSize(
464
+ batch_size, 28, 28, 512,
465
+ 128, 1, 1, 512,
466
+ 0, 0,
467
+ 1, 1,
468
+ 1, 1,
469
+ ))
470
+
471
+ conv2d_problem_vector.append(Conv2DProblemSize(
472
+ batch_size, 28, 28, 512,
473
+ 1024, 1, 1, 512,
474
+ 0, 0,
475
+ 2, 2,
476
+ 1, 1,
477
+ ))
478
+
479
+ conv2d_problem_vector.append(Conv2DProblemSize(
480
+ batch_size, 28, 28, 512,
481
+ 256, 1, 1, 512,
482
+ 0, 0,
483
+ 2, 2,
484
+ 1, 1,
485
+ ))
486
+
487
+ conv2d_problem_vector.append(Conv2DProblemSize(
488
+ batch_size, 14, 14, 256,
489
+ 256, 3, 3, 256,
490
+ 1, 1,
491
+ 1, 1,
492
+ 1, 1,
493
+ ))
494
+
495
+ conv2d_problem_vector.append(Conv2DProblemSize(
496
+ batch_size, 14, 14, 256,
497
+ 1024, 1, 1, 256,
498
+ 0, 0,
499
+ 1, 1,
500
+ 1, 1,
501
+ ))
502
+
503
+ conv2d_problem_vector.append(Conv2DProblemSize(
504
+ batch_size, 14, 14, 1024,
505
+ 256, 1, 1, 1024,
506
+ 0, 0,
507
+ 1, 1,
508
+ 1, 1,
509
+ ))
510
+
511
+ conv2d_problem_vector.append(Conv2DProblemSize(
512
+ batch_size, 14, 14, 1024,
513
+ 2048, 1, 1, 1024,
514
+ 0, 0,
515
+ 2, 2,
516
+ 1, 1,
517
+ ))
518
+
519
+ conv2d_problem_vector.append(Conv2DProblemSize(
520
+ batch_size, 14, 14, 1024,
521
+ 512, 1, 1, 1024,
522
+ 0, 0,
523
+ 2, 2,
524
+ 1, 1,
525
+ ))
526
+
527
+ conv2d_problem_vector.append(Conv2DProblemSize(
528
+ batch_size, 7, 7, 512,
529
+ 512, 3, 3, 512,
530
+ 1, 1,
531
+ 1, 1,
532
+ 1, 1,
533
+ ))
534
+
535
+ conv2d_problem_vector.append(Conv2DProblemSize(
536
+ batch_size, 7, 7, 512,
537
+ 2048, 1, 1, 512,
538
+ 0, 0,
539
+ 1, 1,
540
+ 1, 1,
541
+ ))
542
+
543
+ conv2d_problem_vector.append(Conv2DProblemSize(
544
+ batch_size, 7, 7, 2048,
545
+ 512, 1, 1, 2048,
546
+ 0, 0,
547
+ 1, 1,
548
+ 1, 1,
549
+ ))
550
+
551
+ return conv2d_problem_vector
552
+
553
+ def initialize_conv2d_grouped_sizes(self):
554
+ threadblock_n = 128
555
+ threadblock_k = 32
556
+
557
+ sizes = []
558
+ ##########################################
559
+ # One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0
560
+ # One CTA calculates a single group
561
+ ##########################################
562
+ for cta_per_group_k in range(1, 4):
563
+ for groups in range(2, 5):
564
+ conv_k = cta_per_group_k * threadblock_n * groups
565
+ sizes.append(Conv2DProblemSize(
566
+ 1, 8, 8, threadblock_k * 2 * groups,
567
+ conv_k, 3, 3, threadblock_k * 2,
568
+ 1, 1,
569
+ 1, 1,
570
+ 1, 1,
571
+ ConvMode.CrossCorrelation,
572
+ 1,
573
+ groups
574
+ ))
575
+
576
+ # Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K
577
+ sizes.append(Conv2DProblemSize(
578
+ 1, 8, 8, threadblock_k,
579
+ threadblock_n * 2, 3, 3, threadblock_k // 2,
580
+ 1, 1,
581
+ 1, 1,
582
+ 1, 1,
583
+ ConvMode.CrossCorrelation,
584
+ 1,
585
+ 2
586
+ ))
587
+
588
+ sizes.append(Conv2DProblemSize(
589
+ 1, 56, 56, 696,
590
+ 768, 3, 3, 232,
591
+ 1, 1,
592
+ 2, 2,
593
+ 1, 1,
594
+ ConvMode.CrossCorrelation,
595
+ 1,
596
+ 3
597
+ ))
598
+ sizes.append(Conv2DProblemSize(
599
+ 1, 14, 14, 1392,
600
+ 1536, 3, 3, 232,
601
+ 1, 1,
602
+ 1, 1,
603
+ 1, 1,
604
+ ConvMode.CrossCorrelation,
605
+ 1,
606
+ 3
607
+ ))
608
+
609
+ ##########################################
610
+ # One CTA calculate multiple groups: CTA::N % k_per_group = 0
611
+ ##########################################
612
+
613
+ # 2 groups per CTA
614
+ sizes.append(Conv2DProblemSize(
615
+ 1, 8, 8, threadblock_k * 4,
616
+ threadblock_n, 3, 3, threadblock_k * 2,
617
+ 1, 1,
618
+ 1, 1,
619
+ 1, 1,
620
+ ConvMode.CrossCorrelation,
621
+ 1,
622
+ 2
623
+ ))
624
+
625
+ # 2 groups per CTA and partial gemm_k
626
+ sizes.append(Conv2DProblemSize(
627
+ 1, 8, 8, threadblock_k,
628
+ threadblock_n, 3, 3, threadblock_k // 2,
629
+ 1, 1,
630
+ 1, 1,
631
+ 1, 1,
632
+ ConvMode.CrossCorrelation,
633
+ 1,
634
+ 2
635
+ ))
636
+
637
+ # 4 groups per CTA
638
+ sizes.append(Conv2DProblemSize(
639
+ 1, 8, 8, threadblock_k * 8,
640
+ threadblock_n // 2, 3, 3, threadblock_k * 2,
641
+ 1, 1,
642
+ 1, 1,
643
+ 1, 1,
644
+ ConvMode.CrossCorrelation,
645
+ 1,
646
+ 4
647
+ ))
648
+
649
+ # 4 groups per CTA and partial gemm_k
650
+ sizes.append(Conv2DProblemSize(
651
+ 1, 8, 8, threadblock_k * 2,
652
+ threadblock_n // 2, 3, 3, threadblock_k // 2,
653
+ 1, 1,
654
+ 1, 1,
655
+ 1, 1,
656
+ ConvMode.CrossCorrelation,
657
+ 1,
658
+ 4
659
+ ))
660
+
661
+ return sizes
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_sm80.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for Conv2d opreations on SM80
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend.utils.device import device_cc
42
+
43
+ from conv2d_test_utils import *
44
+
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+ cc = 80
48
+
49
+
50
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is invalid for SM80 tests.')
51
+ class Conv2dSm80(unittest.TestCase):
52
+ """
53
+ Wrapper class to which tests will be added dynamically in __main__
54
+ """
55
+ pass
56
+
57
+
58
+ conv_problems = get_conv_problems()
59
+
60
+
61
+ # Tests for optimized & analytic
62
+ for conv_kind in ["fprop", "wgrad", "dgrad"]:
63
+ # F16, simt
64
+ add_test(
65
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
66
+ opclass="simt", threadblock_shape=[128, 128, 8],
67
+ warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1])
68
+ # F16, tensor op
69
+ add_test(
70
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
71
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
72
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16])
73
+ # F16, tensor op, analytic iterator
74
+ add_test(
75
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16,
76
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
77
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic")
78
+ # F16, tensor op, f32 output
79
+ add_test(
80
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
81
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
82
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16])
83
+ # F16, tensor op, different tile description
84
+ add_test(
85
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
86
+ opclass="tensor_op", threadblock_shape=[128, 64, 32],
87
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8])
88
+ # F32, simt
89
+ add_test(
90
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
91
+ opclass="simt", threadblock_shape=[128, 128, 8],
92
+ warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1])
93
+ # Tf32, tensorop
94
+ add_test(
95
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32,
96
+ opclass="tensor_op", threadblock_shape=[128, 128, 16],
97
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]
98
+ )
99
+ # Split-K
100
+ add_test(
101
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
102
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
103
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial",
104
+ split_k_slices=2)
105
+ add_test(
106
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
107
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
108
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel",
109
+ split_k_slices=5)
110
+ # Swizzling functor
111
+ add_test(
112
+ Conv2dSm80, cc, conv_kind, conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
113
+ opclass="tensor_op", threadblock_shape=[128, 64, 32],
114
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4)
115
+
116
+ # Tests for few channels and fixed channels
117
+ # F16, tensor op, few channels
118
+ for c, tb, stage, inst in zip([2, 1],
119
+ [[128, 128, 64], [128, 128, 32]],
120
+ [3, 2],
121
+ [[16, 8, 16], [16, 8, 8]]):
122
+ add_test(
123
+ Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
124
+ opclass="tensor_op", threadblock_shape=tb,
125
+ warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels"
126
+ )
127
+ # F16, tensor op, fixed channels
128
+ for c in [8, 4, 2]:
129
+ add_test(
130
+ Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
131
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
132
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels"
133
+ )
134
+
135
+ # Test activations
136
+ for activation in ["relu", "leaky_relu"]:
137
+ for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]):
138
+ add_test(
139
+ Conv2dSm80, cc, "fprop", conv_problems, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16,
140
+ opclass="tensor_op", threadblock_shape=[128, 128, 64],
141
+ warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode,
142
+ split_k_slices=split_k_slices, activation=activation)
143
+
144
+
145
+ if __name__ == '__main__':
146
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/conv2d_test_utils.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utility functions for Conv2d tests.
35
+ """
36
+
37
+ from cutlass_library import SubstituteTemplate
38
+ import torch
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_library import (
42
+ ConvKind,
43
+ ConvMode,
44
+ DataType,
45
+ DataTypeNames,
46
+ EpilogueScheduleSuffixes,
47
+ KernelScheduleSuffixes,
48
+ LayoutType,
49
+ OpcodeClassNames,
50
+ ShortDataTypeNames,
51
+ ShortLayoutTypeNames,
52
+ SplitKMode,
53
+ )
54
+ from cutlass_cppgen.shape import Conv2DProblemSize
55
+ from cutlass_cppgen.utils.datatypes import numpy_type, torch_type
56
+
57
+ from conv2d_problem_sizes import TestbedConv2dProblemSizes
58
+
59
+
60
+ def get_name_conv2d(
61
+ arch,
62
+ conv_kind,
63
+ element,
64
+ element_accumulator,
65
+ element_output,
66
+ opclass,
67
+ threadblock_shape,
68
+ warp_count,
69
+ instruction_shape,
70
+ stages,
71
+ iterator_algorithm,
72
+ swizzle,
73
+ split_k_mode,
74
+ split_k_slices,
75
+ activation
76
+ ):
77
+ """
78
+ Generates a procedural name for a test case for conv2d
79
+
80
+ :param arch: compute capability of kernel being generated
81
+ :type arch: int
82
+ :param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad)
83
+ :type conv_kind: str
84
+ :param iterator_algorithm: the iterator algorithm applied
85
+ :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
86
+ :param element_a: data type of operand A
87
+ :param element_b: data type of operand B
88
+ :param element_c: data type of operand C
89
+ :param element_accumulator: data type used in accumulation
90
+ :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
91
+ :type opclass: cutlass_cppgen.OpcodeClass
92
+ :param threadblock_shape: indexable container of dimensions of threadblock tiles
93
+ :param stages: number of pipeline stages to use in the kernel
94
+ :type stages: int
95
+ :param stride_support: stride support of dgrad
96
+ :param alignment: int
97
+ :type alignment: int
98
+
99
+ :return: str
100
+ """
101
+ if iterator_algorithm is None:
102
+ iterator_algorithm = "AUTO"
103
+ if swizzle is None:
104
+ swizzle = 1
105
+ name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}"
106
+
107
+ return SubstituteTemplate(
108
+ name_format,
109
+ {
110
+ "arch": str(arch),
111
+ "conv_kind": conv_kind,
112
+ "iter_alg": iterator_algorithm,
113
+ "eA": DataTypeNames[element],
114
+ "eB": DataTypeNames[element],
115
+ "eC": DataTypeNames[element_output],
116
+ "opclass": opclass,
117
+ "acc": DataTypeNames[element_accumulator],
118
+ "tbM": str(threadblock_shape[0]),
119
+ "tbN": str(threadblock_shape[1]),
120
+ "tbK": str(threadblock_shape[2]),
121
+ "wM": str(threadblock_shape[0] // warp_count[0]),
122
+ "wN": str(threadblock_shape[1] // warp_count[1]),
123
+ "wK": str(threadblock_shape[2] // warp_count[2]),
124
+ "IM": str(instruction_shape[0]),
125
+ "IN": str(instruction_shape[1]),
126
+ "IK": str(instruction_shape[2]),
127
+ "stages": str(stages),
128
+ "swizzle": str(swizzle),
129
+ "split_k_mode": split_k_mode,
130
+ "split_k_slices": str(split_k_slices),
131
+ "activation": activation
132
+ }
133
+ )
134
+
135
+
136
+ def conv2d_few_channel_problemsizes(channels):
137
+ problem_sizes = [
138
+ Conv2DProblemSize(
139
+ 1, 8, 8, channels,
140
+ 16, 3, 3, channels,
141
+ 1, 1,
142
+ 2, 2,
143
+ 1, 1,
144
+ ConvMode.CrossCorrelation,
145
+ 1, 1
146
+ ),
147
+ Conv2DProblemSize(
148
+ 1, 16, 16, channels,
149
+ 16, 3, 3, channels,
150
+ 1, 1,
151
+ 2, 2,
152
+ 1, 1,
153
+ ConvMode.CrossCorrelation,
154
+ 1, 1
155
+ ),
156
+ Conv2DProblemSize(
157
+ 1, 16, 16, channels,
158
+ 16, 7, 7, channels,
159
+ 1, 1,
160
+ 1, 1,
161
+ 1, 1,
162
+ ConvMode.CrossCorrelation,
163
+ 1, 1
164
+ ),
165
+ Conv2DProblemSize(
166
+ 1, 224, 224, channels,
167
+ 32, 7, 7, channels,
168
+ 1, 1,
169
+ 1, 1,
170
+ 1, 1,
171
+ ConvMode.CrossCorrelation,
172
+ 1, 1
173
+ ),
174
+ Conv2DProblemSize(
175
+ 1, 224, 224, channels,
176
+ 64, 7, 7, channels,
177
+ 1, 1,
178
+ 2, 2,
179
+ 1, 1,
180
+ ConvMode.CrossCorrelation,
181
+ 1, 1
182
+ ),
183
+ Conv2DProblemSize(
184
+ 1, 224, 224, channels,
185
+ 64, 5, 5, channels,
186
+ 1, 1,
187
+ 1, 1,
188
+ 1, 1,
189
+ ConvMode.CrossCorrelation,
190
+ 1, 1
191
+ ),
192
+ Conv2DProblemSize(
193
+ 1, 224, 224, channels,
194
+ 64, 5, 5, channels,
195
+ 1, 1,
196
+ 2, 2,
197
+ 1, 1,
198
+ ConvMode.CrossCorrelation,
199
+ 1, 1
200
+ ),
201
+ ]
202
+
203
+ return problem_sizes
204
+
205
+
206
+ def validate_problem_size(ps, conv_kind, split_k_slices):
207
+ P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1
208
+ Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1
209
+ if P != ps.P or Q != ps.Q:
210
+ return False
211
+
212
+ # Split-K (serial or parallel) is not supported for strided dgrad
213
+ if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1):
214
+ return False
215
+ return True
216
+
217
+
218
+ class Conv2dLauncherFrontend:
219
+ def __init__(self, plan: cutlass_cppgen.Conv2d, seed: int = 80, backend="numpy"):
220
+ self.operation = plan
221
+ self.conv_kind = plan.conv_kind
222
+ self.seed = seed
223
+ self.backend = backend
224
+
225
+ self.dtype_A = plan._element_a
226
+ self.dtype_B = plan._element_b
227
+ self.dtype_C = plan._element_c
228
+ self.dtype_acc = plan._element_accumulator
229
+ self.layout_A = LayoutType.TensorNHWC
230
+ self.layout_B = LayoutType.TensorNHWC
231
+ self.layout_C = LayoutType.TensorNHWC
232
+ self.layout_D = LayoutType.TensorNHWC
233
+
234
+ self.element_compute = DataType.f32
235
+
236
+ if self.dtype_A in [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.bf16]:
237
+ self.rand_max = 1
238
+ else:
239
+ self.rand_max = 4
240
+ self.activation = plan.activation
241
+
242
+ def uniform_init(self, size, dtype):
243
+ tensor = torch.ceil(
244
+ torch.empty(size=size, dtype=torch_type(dtype), device="cuda").uniform_(-self.rand_max - 0.5, self.rand_max - 0.5)
245
+ ).to(memory_format=torch.channels_last)
246
+ return tensor
247
+
248
+ def reference(self, ps, A, B, C, alpha, beta, activation):
249
+ if self.conv_kind == ConvKind.Fprop:
250
+ torch_result = alpha * torch.ops.aten.conv2d(
251
+ A,
252
+ B,
253
+ stride=(ps.stride_h, ps.stride_w),
254
+ padding=(ps.pad_h, ps.pad_w),
255
+ dilation=(ps.dilation_h, ps.dilation_w)
256
+ ) + beta * C
257
+ elif self.conv_kind == ConvKind.Dgrad:
258
+ torch_result = alpha * torch.nn.grad.conv2d_input(
259
+ (ps.N, ps.C, ps.H, ps.W),
260
+ B,
261
+ A,
262
+ padding=(ps.pad_h, ps.pad_w),
263
+ stride=(ps.stride_h, ps.stride_w)
264
+ ) + beta * C
265
+ elif self.conv_kind == ConvKind.Wgrad:
266
+ torch_result = alpha * torch.nn.grad.conv2d_weight(
267
+ B,
268
+ (ps.K, ps.C, ps.R, ps.S),
269
+ A,
270
+ padding=(ps.pad_h, ps.pad_w),
271
+ stride=(ps.stride_h, ps.stride_w)
272
+ ) + beta * C
273
+ else:
274
+ raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.")
275
+
276
+ if activation == cutlass_cppgen.backend.epilogue.relu:
277
+ torch_result = torch.nn.functional.relu(torch_result)
278
+ elif activation == cutlass_cppgen.backend.epilogue.leaky_relu:
279
+ torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5)
280
+ return torch_result
281
+
282
+ def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0):
283
+ if self.conv_kind == ConvKind.Fprop:
284
+ tensor_A_size = (ps.N, ps.C, ps.H, ps.W)
285
+ tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
286
+ tensor_C_size = (ps.N, ps.K, ps.P, ps.Q)
287
+ elif self.conv_kind == ConvKind.Dgrad:
288
+ tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
289
+ tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
290
+ tensor_C_size = (ps.N, ps.C, ps.H, ps.W)
291
+ elif self.conv_kind == ConvKind.Wgrad:
292
+ tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
293
+ tensor_B_size = (ps.N, ps.C, ps.H, ps.W)
294
+ tensor_C_size = (ps.K, ps.C, ps.R, ps.S)
295
+ else:
296
+ raise Exception(f"Conv kind {self.conv_kind} is not supported")
297
+
298
+ torch.manual_seed(self.seed)
299
+
300
+ tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A)
301
+ tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B)
302
+ tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C)
303
+ tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last)
304
+ args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
305
+ stride=(ps.stride_h, ps.stride_w),
306
+ padding=(ps.pad_h, ps.pad_w),
307
+ dilation=(ps.dilation_h, ps.dilation_w),
308
+ alpha=alpha, beta=beta,
309
+ split_k=(split_k_mode, split_k_slices))
310
+
311
+ args.sync()
312
+
313
+ tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation)
314
+
315
+ torch.cuda.synchronize()
316
+ passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06)
317
+
318
+ return passed
319
+
320
+
321
+ def add_test(
322
+ cls,
323
+ cc,
324
+ conv_kind,
325
+ problem_sizes,
326
+ element,
327
+ element_accumulator,
328
+ element_output,
329
+ opclass,
330
+ threadblock_shape,
331
+ warp_count,
332
+ instruction_shape,
333
+ stages,
334
+ iterator_algorithm=None,
335
+ swizzle=None,
336
+ split_k_mode="serial",
337
+ split_k_slices=1,
338
+ activation = "identity"
339
+ ):
340
+ """Create a test-running function with the given specification"""
341
+ test_name = get_name_conv2d(
342
+ cc, conv_kind, element, element_accumulator,
343
+ element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages,
344
+ iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation)
345
+
346
+ def run(self):
347
+ # Create the plan
348
+ plan = cutlass_cppgen.Conv2d(
349
+ kind=conv_kind,
350
+ element=element,
351
+ element_accumulator=element_accumulator,
352
+ element_C=element_output,
353
+ element_D=element_output
354
+ )
355
+
356
+ # Set the opclass
357
+ plan.opclass = opclass
358
+ # Set the tile description
359
+ td = {
360
+ "threadblock_shape": threadblock_shape,
361
+ "warp_count": warp_count,
362
+ "stages": stages,
363
+ "instruction_shape": instruction_shape,
364
+ }
365
+
366
+ plan.tile_description = td
367
+ # Set iterator algorithm
368
+ if iterator_algorithm is not None:
369
+ plan.iterator_algorithm = iterator_algorithm
370
+ # Set swizzling functor
371
+ if swizzle is not None:
372
+ plan.swizzling_stride = swizzle
373
+
374
+ if activation != "identity":
375
+ if activation == "leaky_relu":
376
+ plan.activation = (cutlass_cppgen.epilogue.leaky_relu, 0.5)
377
+ else:
378
+ plan.activation = getattr(cutlass_cppgen.epilogue, activation)
379
+
380
+ conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch")
381
+
382
+ for ps in problem_sizes:
383
+ if not validate_problem_size(ps, conv_kind, split_k_slices):
384
+ continue
385
+
386
+ self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0))
387
+
388
+ setattr(cls, test_name, run)
389
+
390
+ return run
391
+
392
+
393
+ def get_conv_problems():
394
+ # 64: minimum channel size
395
+ conv_problems = TestbedConv2dProblemSizes(64).all
396
+
397
+ # Insert alignment 4 & 2 tests
398
+ conv_problems += [
399
+ Conv2DProblemSize(
400
+ 1, 4, 4, 12,
401
+ 8, 3, 3, 12,
402
+ 0, 0,
403
+ 3, 3,
404
+ 1, 1,
405
+ ConvMode.CrossCorrelation,
406
+ 1, 1
407
+ ),
408
+ Conv2DProblemSize(
409
+ 1, 4, 4, 14,
410
+ 8, 3, 3, 14,
411
+ 0, 0,
412
+ 3, 3,
413
+ 1, 1,
414
+ ConvMode.CrossCorrelation,
415
+ 1, 1
416
+ ),
417
+ Conv2DProblemSize(
418
+ 1, 23, 56, 98,
419
+ 128, 3, 3, 98,
420
+ 4, 5,
421
+ 3, 3,
422
+ 1, 1,
423
+ ConvMode.CrossCorrelation,
424
+ 1, 1
425
+ ),
426
+ ]
427
+
428
+ return conv_problems
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/conv2d/run_all_tests.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import pathlib
34
+ import unittest
35
+
36
+
37
+ if __name__ == '__main__':
38
+ loader = unittest.TestLoader()
39
+ script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
40
+ tests = loader.discover(script_dir, 'conv2d_*.py')
41
+ testRunner = unittest.runner.TextTestRunner()
42
+ results = testRunner.run(tests)
43
+ if not results.wasSuccessful():
44
+ raise Exception('Test cases failed')
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/emit/pytorch.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Tests emitting a CUTLASS kernel to a PyTorch CUDA extension
35
+ """
36
+
37
+ import random
38
+ import tempfile
39
+ import unittest
40
+
41
+ from cutlass_library import ConvMode
42
+
43
+ import cutlass_cppgen
44
+
45
+ if cutlass_cppgen.utils.datatypes.is_torch_available():
46
+ import torch
47
+
48
+
49
+ def _initialize(dtype, M: int, N: int, K: int):
50
+ """
51
+ Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K
52
+
53
+ :param dtype: data type of tensors
54
+ :param M: M dimension of GEMM problem
55
+ :type M: int
56
+ :param N: N dimension of GEMM problem
57
+ :type N: int
58
+ :param K: N dimension of GEMM problem
59
+ :type K: int
60
+
61
+ :return: initialized tensors A, B, C, and D
62
+ :rtype: list
63
+ """
64
+ sizes = [(M, K), (K, N), (M, N), (M, N)]
65
+ return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]
66
+
67
+
68
+ def _generate_problems(dtype, num):
69
+ """
70
+ Utility function to generate `num` GEMMs of random sizes
71
+
72
+ :param dtype: data type of tensors
73
+ :param num: number of GEMMs to generate
74
+ :type num: int
75
+
76
+ :return: lists of A, B, C, and D tensors
77
+ :rtype: list
78
+ """
79
+ valid_sizes = [128, 256, 512, 1024]
80
+ As, Bs, Cs, Ds = [], [], [], []
81
+ for _ in range(num):
82
+ M, N, K = [random.choice(valid_sizes) for _ in range(3)]
83
+ A, B, C, D = _initialize(dtype, M, N, K)
84
+ As.append(A)
85
+ Bs.append(B)
86
+ Cs.append(C)
87
+ Ds.append(D)
88
+ return As, Bs, Cs, Ds
89
+
90
+ def _generate_conv2d_problem(conv_kind, dtype, ps):
91
+ """
92
+ Utility function to generate conv2d inputs
93
+
94
+ :param conv_kind: kind of convolution
95
+ :type conv_kind: str
96
+ :param dtype: data type of tensors
97
+ :param problem_size: the conv2d problem size
98
+ :type problem_size: cutlass_cppgen.shape.Conv2DProblemSize
99
+
100
+ :return: initialized tensors A, B, C, and D
101
+ :rtype: list
102
+ """
103
+ if conv_kind == "fprop":
104
+ tensor_A_size = (ps.N, ps.C, ps.H, ps.W)
105
+ tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
106
+ tensor_C_size = (ps.N, ps.K, ps.P, ps.Q)
107
+ elif conv_kind == "dgrad":
108
+ tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
109
+ tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
110
+ tensor_C_size = (ps.N, ps.C, ps.H, ps.W)
111
+ else:
112
+ tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
113
+ tensor_B_size = (ps.N, ps.C, ps.H, ps.W)
114
+ tensor_C_size = (ps.K, ps.C, ps.R, ps.S)
115
+ sizes = [tensor_A_size, tensor_B_size, tensor_C_size]
116
+ return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes]
117
+
118
+
119
+ @unittest.skipIf(not cutlass_cppgen.utils.datatypes.is_torch_available(), 'PyTorch must be available to run PyTorch extension tests')
120
+ class PyTorchExtensionTest(unittest.TestCase):
121
+
122
+ def test_gemm(self):
123
+ random.seed(2023)
124
+
125
+ dtype = torch.float16
126
+ plan = cutlass_cppgen.op.Gemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor)
127
+ op = plan.construct()
128
+
129
+ with tempfile.TemporaryDirectory() as tmpdir:
130
+ mod = cutlass_cppgen.emit.pytorch(op, name='gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True)
131
+
132
+ A, B, C, _ = _initialize(dtype, 1024, 256, 512)
133
+
134
+ D_ref = A @ B
135
+ D = mod.run(A, B)
136
+ assert torch.allclose(D, D_ref)
137
+
138
+ D = mod.run(A, B, C)
139
+ assert torch.allclose(D, D_ref)
140
+
141
+ D = mod.run(A, B, C, 1.0)
142
+ assert torch.allclose(D, D_ref)
143
+
144
+ D = mod.run(A, B, C, 1.0, 0.0)
145
+ assert torch.allclose(D, D_ref)
146
+
147
+ alpha = 2.0
148
+ beta = -1.0
149
+ D_ref = (A @ B) * alpha + (beta * C)
150
+ D = mod.run(A, B, C, alpha, beta)
151
+ assert torch.allclose(D, D_ref)
152
+
153
+ def test_grouped_gemm(self):
154
+ random.seed(2023)
155
+
156
+ dtype = torch.float16
157
+ plan = cutlass_cppgen.op.GroupedGemm(element=dtype, layout=cutlass_cppgen.LayoutType.RowMajor)
158
+ op = plan.construct()
159
+
160
+ with tempfile.TemporaryDirectory() as tmpdir:
161
+ mod = cutlass_cppgen.emit.pytorch(op, name='grouped_gemm_mod', cc=plan.cc, sourcedir=tmpdir, jit=True)
162
+
163
+ As, Bs, Cs, _ = _generate_problems(dtype, 50)
164
+
165
+ def check_all(X, Y):
166
+ for x, y in zip(X, Y):
167
+ assert torch.allclose(x, y)
168
+
169
+ Ds_ref = [a @ b for a, b in zip(As, Bs)]
170
+ Ds = mod.run(As, Bs)
171
+ check_all(Ds, Ds_ref)
172
+
173
+ Ds = mod.run(As, Bs, Cs)
174
+ check_all(Ds, Ds_ref)
175
+
176
+ Ds = mod.run(As, Bs, Cs, 1.0)
177
+ check_all(Ds, Ds_ref)
178
+
179
+ Ds = mod.run(As, Bs, Cs, 1.0, 0.0)
180
+ check_all(Ds, Ds_ref)
181
+
182
+ alpha = 2.0
183
+ beta = -1.0
184
+ Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)]
185
+ Ds = mod.run(As, Bs, Cs, alpha, beta)
186
+ check_all(Ds, Ds_ref)
187
+
188
+ def test_conv2d_fprop(self):
189
+ torch.manual_seed(2023)
190
+
191
+ dtype = torch.float16
192
+ plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32)
193
+ plan.activation = "relu"
194
+
195
+ op = plan.construct()
196
+ with tempfile.TemporaryDirectory() as tmpdir:
197
+ mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
198
+
199
+ problem_size = cutlass_cppgen.shape.Conv2DProblemSize(
200
+ 1, 4, 4, 16,
201
+ 8, 3, 3, 16,
202
+ 0, 0,
203
+ 3, 3,
204
+ 1, 1
205
+ )
206
+
207
+ A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size)
208
+ stride = (problem_size.stride_h, problem_size.stride_w)
209
+ padding = (problem_size.pad_h, problem_size.pad_w)
210
+
211
+ alpha = 1.0
212
+ beta = 0.5
213
+
214
+ D_ref = alpha * torch.ops.aten.conv2d(
215
+ A, B, stride=stride, padding=padding
216
+ ) + beta * C
217
+ D_ref = torch.nn.functional.relu(D_ref)
218
+ D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta)
219
+
220
+ assert torch.allclose(D, D_ref)
221
+
222
+ # Test serial split-K
223
+ D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
224
+ assert torch.allclose(D, D_serial_split_k)
225
+
226
+ # Test parallel split-K
227
+ D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
228
+ assert torch.allclose(D, D_parallel_split_k)
229
+
230
+
231
+ def test_conv2d_dgrad(self):
232
+ torch.manual_seed(2023)
233
+ dtype = torch.float16
234
+ plan = cutlass_cppgen.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32)
235
+
236
+ op = plan.construct()
237
+ with tempfile.TemporaryDirectory() as tmpdir:
238
+ mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
239
+
240
+ problem_size = cutlass_cppgen.shape.Conv2DProblemSize(
241
+ 1, 4, 4, 16,
242
+ 8, 3, 3, 16,
243
+ 0, 0,
244
+ 3, 3,
245
+ 1, 1,
246
+ ConvMode.CrossCorrelation,
247
+ 1, 1
248
+ )
249
+
250
+ A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size)
251
+ stride = (problem_size.stride_h, problem_size.stride_w)
252
+ padding = (problem_size.pad_h, problem_size.pad_w)
253
+
254
+ alpha = 1.0
255
+ beta = 0.5
256
+ input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W)
257
+ D_ref = alpha * torch.nn.grad.conv2d_input(
258
+ input_size, B, A,
259
+ stride=stride, padding=padding
260
+ ) + beta * C
261
+ D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, )
262
+
263
+ assert torch.allclose(D, D_ref)
264
+
265
+ def test_conv2d_wgrad(self):
266
+ torch.manual_seed(2023)
267
+ dtype = torch.float16
268
+ plan = cutlass_cppgen.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32)
269
+
270
+ op = plan.construct()
271
+ with tempfile.TemporaryDirectory() as tmpdir:
272
+ mod = cutlass_cppgen.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True)
273
+
274
+ problem_size = cutlass_cppgen.shape.Conv2DProblemSize(
275
+ 1, 4, 4, 16,
276
+ 8, 3, 3, 16,
277
+ 0, 0,
278
+ 3, 3,
279
+ 1, 1,
280
+ ConvMode.CrossCorrelation,
281
+ 1, 1
282
+ )
283
+
284
+ A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size)
285
+ stride = (problem_size.stride_h, problem_size.stride_w)
286
+ padding = (problem_size.pad_h, problem_size.pad_w)
287
+
288
+ alpha = 1.0
289
+ beta = 0.5
290
+ weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S)
291
+ D_ref = alpha * torch.nn.grad.conv2d_weight(
292
+ B, weight_size, A,
293
+ stride=stride, padding=padding
294
+ ) + beta * C
295
+ D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta)
296
+
297
+ assert torch.allclose(D, D_ref)
298
+
299
+ # Test serial split-K
300
+ D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3)
301
+ assert torch.allclose(D, D_serial_split_k)
302
+
303
+ # Test parallel split-K
304
+ D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7)
305
+ assert torch.allclose(D, D_parallel_split_k)
306
+
307
+
308
+ if __name__ == '__main__':
309
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_compute_sm80_90.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+ """
33
+ Unit test for compute node in SM90
34
+ """
35
+
36
+ import logging
37
+ import unittest
38
+
39
+ import cutlass_cppgen
40
+ from cutlass_cppgen.backend import *
41
+ from cutlass_cppgen.epilogue import *
42
+ from cutlass_cppgen import swizzle
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTCompute(EVTTestCaseBase):
51
+
52
+ def test_arith(self):
53
+ """
54
+ Test Arithmatic op
55
+ """
56
+ def evt_arith_compute(accum, C, alpha, beta, gamma):
57
+ D = ((accum + C) * alpha - gamma) / beta
58
+ return D
59
+
60
+ for m, n, k, l in self.get_problem_sizes(8):
61
+ example_inputs = {
62
+ "accum": self.fake_tensor(self.element, (l, m, n)),
63
+ "C": self.fake_tensor(self.element, (l, m, n)),
64
+ "alpha": 1.5,
65
+ "beta": 0.5,
66
+ "gamma": 2.5,
67
+ "D": self.fake_tensor(self.element, (l, m, n))
68
+ }
69
+
70
+ launcher = EVTTestBed(self.element, evt_arith_compute, example_inputs)
71
+ input_keys = ["C", "alpha", "beta", "gamma"]
72
+ result_keys = ["D"]
73
+ launcher.verify((m, n, k), input_keys, result_keys, l)
74
+
75
+ def test_func_call(self):
76
+ """
77
+ Test Function call
78
+ """
79
+ def evt_func_call(accum, C, alpha, beta, gamma):
80
+ D = multiply_add(relu(accum + alpha) + C, beta, gamma)
81
+ return D
82
+
83
+ for m, n, k, l in self.get_problem_sizes(8):
84
+ example_inputs = {
85
+ "accum": self.fake_tensor(self.element, (l, m, n)),
86
+ "C": self.fake_tensor(self.element, (l, m, n)),
87
+ "alpha": 1.5,
88
+ "beta": 0.5,
89
+ "gamma": 2.5,
90
+ "D": self.fake_tensor(self.element, (l, m, n))
91
+ }
92
+
93
+ launcher = EVTTestBed(self.element, evt_func_call, example_inputs)
94
+ input_keys = ["C", "alpha", "beta", "gamma"]
95
+ result_keys = ["D"]
96
+ launcher.verify((m, n, k), input_keys, result_keys, l)
97
+
98
+ def test_func_call2(self):
99
+ """
100
+ Test Function call
101
+ """
102
+
103
+ def evt_func_call2(accum, C, alpha, beta):
104
+ D = maximum(alpha * accum + beta * C, 0.0)
105
+ return D
106
+
107
+ for m, n, k, l in self.get_problem_sizes(8):
108
+ example_inputs = {
109
+ "accum": self.fake_tensor(self.element, (l, m, n)),
110
+ "C": self.fake_tensor(self.element, (l, m, n)),
111
+ "alpha": 1.5,
112
+ "beta": 0.5,
113
+ "D": self.fake_tensor(self.element, (l, m, n))
114
+ }
115
+
116
+ launcher = EVTTestBed(self.element, evt_func_call2, example_inputs)
117
+ input_keys = ["C", "alpha", "beta"]
118
+ result_keys = ["D"]
119
+ launcher.verify((m, n, k), input_keys, result_keys, l)
120
+
121
+ def test_tanh(self):
122
+ """
123
+ Test Tanh op
124
+ """
125
+ def evt_tanh(accum):
126
+ D = tanh(accum)
127
+ return D
128
+
129
+ for m, n, k, l in self.get_problem_sizes(8):
130
+ example_inputs = {
131
+ "accum": self.fake_tensor(self.element, (l, m, n)),
132
+ "D": self.fake_tensor(self.element, (l, m, n))
133
+ }
134
+
135
+ launcher = EVTTestBed(self.element, evt_tanh, example_inputs)
136
+ input_keys = []
137
+ result_keys = ["D"]
138
+ launcher.verify((m, n, k), input_keys, result_keys, l)
139
+
140
+ def test_sigmoid(self):
141
+ """
142
+ Test Sigmoid op
143
+ """
144
+ def evt_sigmoid(accum):
145
+ D = sigmoid(accum)
146
+ return D
147
+
148
+ for m, n, k, l in self.get_problem_sizes(8):
149
+ example_inputs = {
150
+ "accum": self.fake_tensor(self.element, (l, m, n)),
151
+ "D": self.fake_tensor(self.element, (l, m, n))
152
+ }
153
+
154
+ launcher = EVTTestBed(self.element, evt_sigmoid, example_inputs)
155
+ input_keys = []
156
+ result_keys = ["D"]
157
+ launcher.verify((m, n, k), input_keys, result_keys, l)
158
+
159
+ def test_gelu(self):
160
+ """
161
+ Test GELU op
162
+ """
163
+ def evt_gelu(accum):
164
+ D = gelu(accum)
165
+ return D
166
+
167
+ for m, n, k, l in self.get_problem_sizes(8):
168
+ example_inputs = {
169
+ "accum": self.fake_tensor(self.element, (l, m, n)),
170
+ "D": self.fake_tensor(self.element, (l, m, n))
171
+ }
172
+
173
+ launcher = EVTTestBed(self.element, evt_gelu, example_inputs)
174
+ input_keys = []
175
+ result_keys = ["D"]
176
+ launcher.verify((m, n, k), input_keys, result_keys, l)
177
+
178
+ def test_exp(self):
179
+ """
180
+ Test Exp op
181
+ """
182
+ def evt_exp(accum):
183
+ D = exp(accum)
184
+ return D
185
+
186
+ for m, n, k, l in self.get_problem_sizes(8):
187
+ example_inputs = {
188
+ "accum": self.fake_tensor(self.element, (l, m, n)),
189
+ "D": self.fake_tensor(self.element, (l, m, n))
190
+ }
191
+
192
+ launcher = EVTTestBed(self.element, evt_exp, example_inputs)
193
+ input_keys = []
194
+ result_keys = ["D"]
195
+ launcher.verify((m, n, k), input_keys, result_keys, l)
196
+
197
+ if __name__ == '__main__':
198
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unit test for store nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTLayout(EVTTestCaseBase):
51
+
52
+ def test_permute_1(self):
53
+ """
54
+ Returning a tensor with shape [m, n]
55
+ """
56
+ def evt_permute(accum, alpha, C):
57
+ F = alpha * accum
58
+ F_permute = permute(F, indices=(0, 2, 1))
59
+ D_permute = F_permute + permute(C, indices=(0, 2, 1))
60
+ D = permute(D_permute, indices=(0, 2, 1))
61
+ return D, F
62
+
63
+ for m, n, k, l in self.get_problem_sizes(8):
64
+ example_inputs = {
65
+ "accum": self.fake_tensor(self.element, (l, m, n)),
66
+ "alpha": 0.5,
67
+ "C": self.fake_tensor(self.element, (l, m, n)),
68
+ "F": self.fake_tensor(self.element, (l, m, n)),
69
+ "D": self.fake_tensor(self.element, (l, m, n)),
70
+ }
71
+
72
+ launcher = EVTTestBed(self.element, evt_permute, example_inputs)
73
+ input_keys = ["C", "alpha"]
74
+ result_keys = ["D", "F"]
75
+ launcher.verify((m, n, k), input_keys, result_keys, l)
76
+
77
+ @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
78
+ def test_permute_2(self):
79
+ """
80
+ Returning a tensor with shape [m, n]
81
+ """
82
+ def evt_permute(accum, alpha, C):
83
+ F = alpha * accum
84
+ F_permute = permute(F, indices=(0, 2, 1))
85
+ D = F_permute + C
86
+ return D, F
87
+
88
+ for m, n, k, l in self.get_problem_sizes(8):
89
+ example_inputs = {
90
+ "accum": self.fake_tensor(self.element, (l, m, n)),
91
+ "alpha": 0.5,
92
+ "C": self.fake_tensor(self.element, (l, n, m)),
93
+ "F": self.fake_tensor(self.element, (l, m, n)),
94
+ "D": self.fake_tensor(self.element, (l, n, m)),
95
+ }
96
+
97
+ launcher = EVTTestBed(self.element, evt_permute, example_inputs)
98
+ input_keys = ["C", "alpha"]
99
+ result_keys = ["D", "F"]
100
+ launcher.verify((m, n, k), input_keys, result_keys, l)
101
+
102
+ @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
103
+ def test_permute_3(self):
104
+ """
105
+ Returning a tensor with shape [m, n]
106
+ """
107
+ def evt_permute(accum, alpha, C):
108
+ F = alpha * accum
109
+ F_permute = permute(F, indices=(1, 0, 2))
110
+ D = F_permute + C
111
+ return D, F
112
+
113
+ for m, n, k, l in self.get_problem_sizes(8):
114
+ example_inputs = {
115
+ "accum": self.fake_tensor(self.element, (l, m, n)),
116
+ "alpha": 0.5,
117
+ "C": self.fake_tensor(self.element, (m, l, n)),
118
+ "F": self.fake_tensor(self.element, (l, m, n)),
119
+ "D": self.fake_tensor(self.element, (m, l, n)),
120
+ }
121
+
122
+ launcher = EVTTestBed(self.element, evt_permute, example_inputs)
123
+ input_keys = ["C", "alpha"]
124
+ result_keys = ["D", "F"]
125
+ launcher.verify((m, n, k), input_keys, result_keys, l)
126
+
127
+ def test_reshape(self):
128
+ """
129
+ Test reshape
130
+ """
131
+ def evt_reshape(accum, alpha, TensorE):
132
+ F = alpha * accum
133
+ E_reshape = reshape(TensorE, new_shape=(512, 1))
134
+ D = F + E_reshape
135
+ return D
136
+
137
+ example_inputs = {
138
+ "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
139
+ "alpha": 0.5,
140
+ "TensorE": self.fake_tensor(self.element, (16, 32)),
141
+ "D": self.fake_tensor(self.element, (self.l, self.m, self.n)),
142
+ }
143
+
144
+ launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
145
+ input_keys = ["alpha", "TensorE"]
146
+ result_keys = ["D"]
147
+ launcher.verify(self.problem_size, input_keys, result_keys, self.l)
148
+
149
+ def test_reshape2(self):
150
+ """
151
+ Test reshape
152
+ """
153
+ def evt_reshape(accum, alpha, TensorE):
154
+ F = alpha * accum
155
+ F_reshape = reshape(F, new_shape=(2, 3, 512, 256))
156
+ D = F_reshape + TensorE
157
+ return D
158
+
159
+ example_inputs = {
160
+ "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
161
+ "alpha": 0.5,
162
+ "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)),
163
+ "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)),
164
+ }
165
+
166
+ launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
167
+ input_keys = ["alpha", "TensorE"]
168
+ result_keys = ["D"]
169
+ launcher.verify(self.problem_size, input_keys, result_keys, self.l)
170
+
171
+
172
+ if __name__ == '__main__':
173
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unit test for load nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTLoad(EVTTestCaseBase):
51
+
52
+ def test_tensor_load(self):
53
+ """
54
+ Load extra tensor with shape [m, n]
55
+ """
56
+ def evt_tensor_load(accum, C, aux, aux_batch):
57
+ D = accum + C + aux + aux_batch
58
+ return D
59
+
60
+ for m, n, k, l in self.get_problem_sizes(8):
61
+ example_inputs = {
62
+ "accum": self.fake_tensor(self.element, (l, m, n)),
63
+ "C": self.fake_tensor(self.element, (l, m, n)),
64
+ "aux": self.fake_tensor(self.element, (m, n)),
65
+ "aux_batch": self.fake_tensor(np.float32, (l, m, n)),
66
+ "D": self.fake_tensor(self.element, (l, m, n)),
67
+ }
68
+
69
+ launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs)
70
+ input_keys = ["C", "aux", "aux_batch"]
71
+ result_keys = ["D"]
72
+ launcher.verify((m, n, k), input_keys, result_keys, l)
73
+
74
+ def test_row_broadcast(self):
75
+ """
76
+ Load extra tensor with shape [1, n]
77
+ """
78
+ def evt_row_broadcast(accum, C, bias, bias_batch):
79
+ D = accum + C + bias + bias_batch
80
+ return D
81
+
82
+ for m, n, k, l in self.get_problem_sizes(8):
83
+ example_inputs = {
84
+ "accum": self.fake_tensor(self.element, (l, m, n)),
85
+ "C": self.fake_tensor(self.element, (l, m, n)),
86
+ "bias": self.fake_tensor(self.element, (n,)),
87
+ "bias_batch": self.fake_tensor(np.float32, (l, 1, n)),
88
+ "D": self.fake_tensor(self.element, (l, m, n)),
89
+ }
90
+
91
+ launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs)
92
+ input_keys = ["C", "bias", "bias_batch"]
93
+ result_keys = ["D"]
94
+ launcher.verify((m, n, k), input_keys, result_keys, l)
95
+
96
+ def test_column_broadcast(self):
97
+ """
98
+ Load extra tensor with shape [m, 1]
99
+ """
100
+ def evt_column_broadcast(accum, C, bias, bias_batch):
101
+ D = accum + C + bias + bias_batch
102
+ return D
103
+
104
+ for m, n, k, l in self.get_problem_sizes(8):
105
+ example_inputs = {
106
+ "accum": self.fake_tensor(self.element, (l, m, n)),
107
+ "C": self.fake_tensor(self.element, (l, m, n)),
108
+ "bias": self.fake_tensor(self.element, (m, 1)),
109
+ "bias_batch": self.fake_tensor(np.float32, (l, m, 1)),
110
+ "D": self.fake_tensor(self.element, (l, m, n)),
111
+ }
112
+
113
+ launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs)
114
+ input_keys = ["C", "bias", "bias_batch"]
115
+ result_keys = ["D"]
116
+ launcher.verify((m, n, k), input_keys, result_keys, l)
117
+
118
+ def test_scalar_broadcast(self):
119
+ """
120
+ Load extra tensor with shape [1, 1]
121
+ """
122
+ def evt_scalar_broadcast(accum, C, alpha, alpha_batch):
123
+ D = accum + C + alpha + alpha_batch
124
+ return D
125
+
126
+ for m, n, k, l in self.get_problem_sizes(8):
127
+ example_inputs = {
128
+ "accum": self.fake_tensor(self.element, (l, m, n)),
129
+ "C": self.fake_tensor(self.element, (l, m, n)),
130
+ "alpha": 0.5,
131
+ "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)),
132
+ "D": self.fake_tensor(self.element, (l, m, n)),
133
+ }
134
+
135
+ launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs)
136
+ input_keys = ["C", "alpha", "alpha_batch"]
137
+ result_keys = ["D"]
138
+ launcher.verify((m, n, k), input_keys, result_keys, l)
139
+
140
+
141
+ if __name__ == '__main__':
142
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unittest for mixed types of nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+ from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK
44
+
45
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+
49
+
50
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
51
+ class TestEVTMixed(EVTTestCaseBase):
52
+
53
+ def test_same_variable_used_multiple_times(self):
54
+ """
55
+ The same variable z0 is used multiple times
56
+ """
57
+ def evt_aux_store(accum):
58
+ z0 = relu(accum)
59
+ D = z0 + z0
60
+ return z0, D
61
+
62
+ for m, n, k, l in self.get_problem_sizes(8):
63
+ example_inputs = {
64
+ "accum": self.fake_tensor(self.element, (l, m, n)),
65
+ "D": self.fake_tensor(self.element, (l, m, n)),
66
+ "z0": self.fake_tensor(self.element, (l, m, n)),
67
+ }
68
+
69
+ launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
70
+ input_keys = ["accum"]
71
+ result_keys = ["z0", "D"]
72
+ launcher.verify((m, n, k), input_keys, result_keys, l)
73
+
74
+ def test_no_lca(self):
75
+ """
76
+ The same variable z0 is used multiple times
77
+ """
78
+ def evt_no_lca(accum, bias):
79
+ E = relu(accum)
80
+ F = E + bias
81
+ tmp_2 = E + 2
82
+ D = tmp_2 + E
83
+ return D
84
+
85
+ for m, n, k, l in self.get_problem_sizes(8):
86
+ example_inputs = {
87
+ "accum": self.fake_tensor(self.element, (l, m, n)),
88
+ "D": self.fake_tensor(self.element, (l, m, n)),
89
+ "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)),
90
+ }
91
+
92
+ launcher = EVTTestBed(self.element, evt_no_lca, example_inputs)
93
+ input_keys = ["accum", "bias"]
94
+ result_keys = ["D"]
95
+ launcher.verify((m, n, k), input_keys, result_keys, l)
96
+
97
+ def test_mixed_dag(self):
98
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
99
+ F = alpha * accum + (beta * C + aux)
100
+ F_row_max = max(F, dim=[0, 1])
101
+ E = relu(F + 1) + cbias + rbias
102
+ E_col_max = max(E, dim=[0, 2])
103
+ D = E + F
104
+ return D, F, F_row_max, E_col_max
105
+
106
+ if device_cc() == 80:
107
+ alignments = [2, 4, 8]
108
+ else:
109
+ # Sm90 EVT currently only supports 128-bit alignment
110
+ alignments = [8,]
111
+ for align in alignments:
112
+ for m, n, k, l in self.get_problem_sizes(align):
113
+ example_inputs = {
114
+ "accum": self.fake_tensor(self.element, (l, m, n)),
115
+ "alpha": 1.0,
116
+ "C": self.fake_tensor(self.element, (l, m, n)),
117
+ "beta": 1.0,
118
+ "aux": self.fake_tensor(self.element, (l, m, n)),
119
+ "cbias": self.fake_tensor(self.element, (m, 1)),
120
+ "rbias": self.fake_tensor(self.element, (n,)),
121
+ "D": self.fake_tensor(self.element, (l, m, n)),
122
+ "F": self.fake_tensor(self.element, (l, m, n)),
123
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
124
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
125
+ }
126
+
127
+ launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs)
128
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
129
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
130
+ launcher.verify((m, n, k), input_keys, result_keys, l)
131
+
132
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
133
+ def test_mixed_dag_float(self):
134
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
135
+ F = alpha * accum + (beta * C + aux)
136
+ F_row_max = max(F, dim=[0, 1])
137
+ E = relu(F + 1) + cbias + rbias
138
+ E_col_max = max(E, dim=[0, 2])
139
+ D = E + F
140
+ return D, F, F_row_max, E_col_max
141
+
142
+ for align in [3, 2, 4]:
143
+ for m, n, k, l in self.get_problem_sizes(align):
144
+ example_inputs = {
145
+ "accum": self.fake_tensor(np.float32, (l, m, n)),
146
+ "alpha": 1.0,
147
+ "C": self.fake_tensor(np.float32, (l, m, n)),
148
+ "beta": 1.0,
149
+ "aux": self.fake_tensor(np.float32, (l, m, n)),
150
+ "cbias": self.fake_tensor(np.float32, (m, 1)),
151
+ "rbias": self.fake_tensor(np.float32, (n,)),
152
+ "D": self.fake_tensor(np.float32, (l, m, n)),
153
+ "F": self.fake_tensor(np.float32, (l, m, n)),
154
+ "F_row_max": self.fake_tensor(np.float32, (n,)),
155
+ "E_col_max": self.fake_tensor(np.float32, (m, 1))
156
+ }
157
+ launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs)
158
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
159
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
160
+ launcher.verify((m, n, k), input_keys, result_keys, l)
161
+
162
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
163
+ def test_mixed_dag_stage2(self):
164
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
165
+ F = alpha * accum + (beta * C + aux)
166
+ F_row_max = max(F, dim=[0, 1])
167
+ E = relu(F + 1) + cbias + rbias
168
+ E_col_max = max(E, dim=[0, 2])
169
+ D = E + F
170
+ return D, F, F_row_max, E_col_max
171
+
172
+ for m, n, k, l in self.get_problem_sizes(8):
173
+ example_inputs = {
174
+ "accum": self.fake_tensor(self.element, (l, m, n)),
175
+ "alpha": 1.0,
176
+ "C": self.fake_tensor(self.element, (l, m, n)),
177
+ "beta": 1.0,
178
+ "aux": self.fake_tensor(self.element, (l, m, n)),
179
+ "cbias": self.fake_tensor(self.element, (m, 1)),
180
+ "rbias": self.fake_tensor(self.element, (n,)),
181
+ "D": self.fake_tensor(self.element, (l, m, n)),
182
+ "F": self.fake_tensor(self.element, (l, m, n)),
183
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
184
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
185
+ }
186
+
187
+ launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2)
188
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
189
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
190
+ launcher.verify((m, n, k), input_keys, result_keys, l)
191
+
192
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
193
+ def test_mixed_dag_partition_k(self):
194
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
195
+ F = alpha * accum + (beta * C + aux)
196
+ F_row_max = max(F, dim=[0, 1])
197
+ E = relu(F + 1) + cbias + rbias
198
+ E_col_max = max(E, dim=[0, 2])
199
+ D = E + F
200
+ return D, F, F_row_max, E_col_max
201
+
202
+ for m, n, k, l in self.get_problem_sizes(8):
203
+ example_inputs = {
204
+ "accum": self.fake_tensor(self.element, (l, m, n)),
205
+ "alpha": 1.0,
206
+ "C": self.fake_tensor(self.element, (l, m, n)),
207
+ "beta": 1.0,
208
+ "aux": self.fake_tensor(self.element, (l, m, n)),
209
+ "cbias": self.fake_tensor(self.element, (m, 1)),
210
+ "rbias": self.fake_tensor(self.element, (n,)),
211
+ "D": self.fake_tensor(self.element, (l, m, n)),
212
+ "F": self.fake_tensor(self.element, (l, m, n)),
213
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
214
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
215
+ }
216
+
217
+ tile_description = {
218
+ "threadblock_shape": [128, 128, 64],
219
+ "warp_count": [2, 2, 2]
220
+ }
221
+
222
+ launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2)
223
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
224
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
225
+ launcher.verify((m, n, k), input_keys, result_keys, l)
226
+
227
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
228
+ def test_mixed_dag_stream_k(self):
229
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
230
+ F = alpha * accum + (beta * C + aux)
231
+ F_row_max = max(F, dim=[0, 1])
232
+ E = relu(F + 1) + cbias + rbias
233
+ E_col_max = max(E, dim=[0, 2])
234
+ D = E + F
235
+ return D, F, F_row_max, E_col_max
236
+
237
+ # High per-sm occupancy tile_description
238
+ tile_description = {
239
+ "threadblock_shape": [128, 128, 32],
240
+ "warp_count": [2, 2, 1],
241
+ "stages": 3
242
+ }
243
+ tds = [None, tile_description]
244
+ for td in tds:
245
+ for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]):
246
+ if l == 1:
247
+ example_inputs = {
248
+ "accum": self.fake_tensor(self.element, (m, n)),
249
+ "alpha": 1.0,
250
+ "C": self.fake_tensor(self.element, (m, n)),
251
+ "beta": 1.0,
252
+ "aux": self.fake_tensor(self.element, (m, n)),
253
+ "cbias": self.fake_tensor(self.element, (m, 1)),
254
+ "rbias": self.fake_tensor(self.element, (n,)),
255
+ "D": self.fake_tensor(self.element, (m, n)),
256
+ "F": self.fake_tensor(self.element, (m, n)),
257
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
258
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
259
+ }
260
+ else:
261
+ example_inputs = {
262
+ "accum": self.fake_tensor(self.element, (l, m, n)),
263
+ "alpha": 1.0,
264
+ "C": self.fake_tensor(self.element, (l, m, n)),
265
+ "beta": 1.0,
266
+ "aux": self.fake_tensor(self.element, (l, m, n)),
267
+ "cbias": self.fake_tensor(self.element, (m, 1)),
268
+ "rbias": self.fake_tensor(self.element, (n,)),
269
+ "D": self.fake_tensor(self.element, (l, m, n)),
270
+ "F": self.fake_tensor(self.element, (l, m, n)),
271
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
272
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
273
+ }
274
+
275
+ if td is not None:
276
+ launcher = EVTTestBed(
277
+ self.element, evt_mixed_dag, example_inputs,
278
+ tile_description=td,
279
+ swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
280
+ else:
281
+ launcher = EVTTestBed(
282
+ self.element, evt_mixed_dag, example_inputs,
283
+ swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
284
+
285
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
286
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
287
+ launcher.verify((m, n, k), input_keys, result_keys, l)
288
+
289
+ def test_mixed_dag_no_batch(self):
290
+ def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias):
291
+ F = alpha * accum + (beta * C + aux)
292
+ F_row_max = max(F, dim=[0, 1])
293
+ E = relu(F + 1) + cbias + rbias
294
+ E_col_max = max(E, dim=[0, 2])
295
+ D = E + F
296
+ return D, F, F_row_max, E_col_max
297
+
298
+ for m, n, k, _ in self.get_problem_sizes(8):
299
+ example_inputs = {
300
+ "accum": self.fake_tensor(self.element, (m, n)),
301
+ "alpha": 1.0,
302
+ "C": self.fake_tensor(self.element, (m, n)),
303
+ "beta": 1.0,
304
+ "aux": self.fake_tensor(self.element, (m, n)),
305
+ "cbias": self.fake_tensor(self.element, (m, 1)),
306
+ "rbias": self.fake_tensor(self.element, (n,)),
307
+ "D": self.fake_tensor(self.element, (m, n)),
308
+ "F": self.fake_tensor(self.element, (m, n)),
309
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
310
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
311
+ }
312
+
313
+ launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs)
314
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
315
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
316
+ launcher.verify((m, n, k), input_keys, result_keys, 1)
317
+
318
+ if __name__ == '__main__':
319
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unit test for store nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTStore(EVTTestCaseBase):
51
+
52
+ @unittest.skipIf(device_cc() != 90, "This test is only for CC 90")
53
+ def test_invalid_store(self):
54
+ """
55
+ Test invalid store
56
+ """
57
+ def evt_invalid_store(accum):
58
+ D = accum
59
+ F = D + 1 # D has users, which is not allowed on SM90 or higher
60
+ return D, F
61
+
62
+ for m, n, k, l in self.get_problem_sizes(8):
63
+ example_inputs = {
64
+ "accum": self.fake_tensor(self.element, (l, m, n)),
65
+ "D": self.fake_tensor(self.element, (l, m, n)),
66
+ "F": self.fake_tensor(self.element, (l, m, n))
67
+ }
68
+ with self.assertRaisesRegex(
69
+ RuntimeError,
70
+ r"On SM90 or higher, D is expected to be a output node with 0 users "
71
+ r"to enable smem reuse between C and D, but got 1"
72
+ ):
73
+ launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs)
74
+
75
+ break # Only need to test once
76
+
77
+ def test_aux_store(self):
78
+ """
79
+ Returning a tensor with shape [m, n]
80
+ """
81
+ def evt_aux_store(accum, alpha, C):
82
+ F = alpha * accum
83
+ D = F + C
84
+ return D, F
85
+
86
+ for m, n, k, l in self.get_problem_sizes(8):
87
+ example_inputs = {
88
+ "accum": self.fake_tensor(self.element, (l, m, n)),
89
+ "alpha": 0.5,
90
+ "C": self.fake_tensor(self.element, (l, m, n)),
91
+ "F": self.fake_tensor(self.element, (l, m, n)),
92
+ "D": self.fake_tensor(self.element, (l, m, n)),
93
+ }
94
+
95
+ launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
96
+ input_keys = ["C", "alpha"]
97
+ result_keys = ["D", "F"]
98
+ launcher.verify((m, n, k), input_keys, result_keys, l)
99
+
100
+ def test_col_reduce(self):
101
+ """
102
+ Reduction [m, n] -> [m, 1]
103
+ """
104
+ def evt_row_reduce(accum, alpha, C):
105
+ acc_row_max = max(accum, dim=[2,])
106
+ F = alpha * accum
107
+ F_row_max = max(F, dim=[0, 2])
108
+ D = F + C
109
+ return D, F_row_max, acc_row_max
110
+
111
+ for m, n, k, l in self.get_problem_sizes(8):
112
+ example_inputs = {
113
+ "accum": self.fake_tensor(self.element, (l, m, n)),
114
+ "alpha": 2.0,
115
+ "C": self.fake_tensor(self.element, (l, m, n)),
116
+ "F_row_max": self.fake_tensor(np.float32, (m, 1)),
117
+ "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)),
118
+ "D": self.fake_tensor(self.element, (l, m, n)),
119
+ }
120
+
121
+ launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs)
122
+ input_keys = ["C", "alpha"]
123
+ result_keys = ["D", "F_row_max", "acc_row_max"]
124
+ launcher.verify((m, n, k), input_keys, result_keys, l)
125
+
126
+ def test_row_reduce(self):
127
+ """
128
+ Reduction [m, n] -> [n]
129
+ """
130
+ def evt_col_reduce(accum, alpha, C):
131
+ acc_col_max = max(accum, dim=[1,])
132
+ F = alpha * accum
133
+ F_col_max = max(F, dim=[0, 1])
134
+ D = F + C
135
+ return D, F_col_max, acc_col_max
136
+
137
+ for m, n, k, l in self.get_problem_sizes(8):
138
+ example_inputs = {
139
+ "accum": self.fake_tensor(self.element, (l, m, n)),
140
+ "alpha": 2.0,
141
+ "C": self.fake_tensor(self.element, (l, m, n)),
142
+ "F_col_max": self.fake_tensor(np.float32, (n,)),
143
+ "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)),
144
+ "D": self.fake_tensor(self.element, (l, m, n)),
145
+ }
146
+
147
+ launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs)
148
+ input_keys = ["C", "alpha"]
149
+ result_keys = ["D", "F_col_max", "acc_col_max"]
150
+ launcher.verify((m, n, k), input_keys, result_keys, l)
151
+
152
+ def test_scalar_reduce(self):
153
+ """
154
+ Reduction [m, n] -> [1,]
155
+ """
156
+ def evt_scalar_reduce(accum, alpha, C):
157
+ acc_max = max(accum, dim=[1, 2])
158
+ F = alpha * accum
159
+ F_max = max(F, dim=[0, 1, 2])
160
+ D = F + C
161
+ return D, F_max, acc_max
162
+
163
+ for m, n, k, l in self.get_problem_sizes(8):
164
+ example_inputs = {
165
+ "accum": self.fake_tensor(self.element, (l, m, n)),
166
+ "alpha": 2.0,
167
+ "C": self.fake_tensor(self.element, (l, m, n)),
168
+ "acc_max": self.fake_tensor(np.float32, (l, 1, 1)),
169
+ "F_max": self.fake_tensor(np.float32, (1,)),
170
+ "D": self.fake_tensor(self.element, (l, m, n)),
171
+ }
172
+
173
+ launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs)
174
+ input_keys = ["C", "alpha"]
175
+ result_keys = ["D", "F_max", "acc_max"]
176
+ launcher.verify((m, n, k), input_keys, result_keys, l)
177
+
178
+
179
+ if __name__ == '__main__':
180
+ unittest.main()
build/torch212-cxx11-cu132-x86_64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import pathlib
34
+ import unittest
35
+
36
+
37
+ if __name__ == '__main__':
38
+ loader = unittest.TestLoader()
39
+ script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
40
+ tests = loader.discover(script_dir, 'evt_*.py')
41
+ testRunner = unittest.runner.TextTestRunner()
42
+ results = testRunner.run(tests)
43
+ if not results.wasSuccessful():
44
+ raise Exception('Test cases failed')