kernels-bot commited on
Commit
d7ecb94
·
verified ·
1 Parent(s): 92f2707

Uploaded using `kernel-builder` (batch 12/32).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py +0 -336
  2. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py +0 -294
  3. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py +0 -306
  4. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py +0 -277
  5. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py +0 -137
  6. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py +0 -42
  7. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py +0 -143
  8. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py +0 -120
  9. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py +0 -169
  10. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py +0 -64
  11. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py +0 -90
  12. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py +0 -217
  13. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py +0 -164
  14. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py +0 -53
  15. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py +0 -97
  16. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py +0 -59
  17. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py +0 -319
  18. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py +0 -46
  19. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py +0 -109
  20. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py +0 -2145
  21. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py +0 -509
  22. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py +0 -121
  23. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py +0 -140
  24. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py +0 -455
  25. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py +0 -35
  26. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py +0 -33
  27. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py +0 -126
  28. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py +0 -33
  29. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py +0 -267
  30. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py +0 -936
  31. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py +0 -56
  32. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py +0 -176
  33. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py +0 -98
  34. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py +0 -569
  35. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py +0 -36
  36. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py +0 -997
  37. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py +0 -725
  38. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py +0 -269
  39. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py +0 -431
  40. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py +0 -184
  41. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py +0 -65
  42. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py +0 -41
  43. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py +0 -262
  44. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py +0 -362
  45. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py +0 -41
  46. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py +0 -196
  47. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py +0 -63
  48. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py +0 -621
  49. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py +0 -482
  50. build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py +0 -250
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py DELETED
@@ -1,336 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Layout manipulation nodes and implementations
35
-
36
- The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
37
- """
38
-
39
- from copy import deepcopy
40
-
41
- from cutlass_library import LayoutType
42
- from pycute import product, flatten
43
-
44
- import cutlass_cppgen
45
- from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
46
- from cutlass_cppgen.backend.evt.ir.node import NodeBase
47
- from cutlass_cppgen.backend.evt.ir.tensor import Tensor
48
-
49
-
50
- class PermutationImpl:
51
- """
52
- Detailed implementation and helper functions for permutation
53
- """
54
- def __init__(self, node) -> None:
55
- assert "indices" in node.kwargs.keys()
56
- self.indices = list(node.kwargs["indices"])
57
- self.inverse_indices = self.get_inverse_indices(self.indices)
58
-
59
- def get_inverse_impl(self):
60
- inverse_impl = deepcopy(self)
61
- inverse_impl.indices = self.inverse_indices
62
- inverse_impl.inverse_indices = self.indices
63
- return inverse_impl
64
-
65
- def update(self, shape):
66
- num_dim = len(shape)
67
- indices = self.indices
68
- num_old_dim = len(indices)
69
- # Add offset
70
- for i, idx in enumerate(indices):
71
- indices[i] = idx + num_dim - num_old_dim
72
- # Add broadcast dims
73
- for i in range(num_dim - num_old_dim):
74
- indices = [i,] + indices
75
-
76
- self.indices = indices
77
- self.inverse_indices = self.get_inverse_indices(self.indices)
78
-
79
- def get_inverse_indices(self, indices):
80
- """
81
- Get the indices for inverse permutation
82
- """
83
- num_dim = len(indices)
84
- inverse_indices = [0] * num_dim
85
- for i in range(num_dim):
86
- inverse_indices[indices[i]] = i
87
- return inverse_indices
88
-
89
- def shape_propagation(self, input_node_meta):
90
- input_shape = input_node_meta.tensor.shape
91
- output_shape = tuple([input_shape[idx] for idx in self.indices])
92
- return output_shape
93
-
94
- def broadcast(self, shape, node_meta: NodeBase):
95
- """
96
- Broadcast the inputs based on current shape
97
- """
98
- self.update(shape)
99
- inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
100
- node_meta.tensor.broadcast(inverse_shape)
101
-
102
- def apply_to_user(self, usr_meta: NodeBase):
103
- """
104
- Propagate the permutation to the users of the current nodes
105
- """
106
- usr_meta.tensor.permute(self.inverse_indices)
107
- if hasattr(usr_meta, "store_tensor"):
108
- if usr_meta.store_tensor is not None:
109
- usr_meta.store_tensor.permute(self.inverse_indices)
110
-
111
- def apply_to_input(self, input_meta: NodeBase):
112
- """
113
- Propagate the permutation to inputs of the current nodes
114
- """
115
- input_meta.tensor.permute(self.indices)
116
- if hasattr(input_meta, "store_tensor"):
117
- if input_meta.store_tensor is not None:
118
- input_meta.store_tensor.permute(self.indices)
119
-
120
-
121
- class ReshapeImpl:
122
- """
123
- Detailed implementation and helper functions for reshape
124
- """
125
- def __init__(self, node) -> None:
126
- self.node = node
127
- assert "new_shape" in node.kwargs.keys()
128
- self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
129
-
130
- def get_inverse_impl(self):
131
- inverse_impl = deepcopy(self)
132
- inverse_impl.output_shape = self.input_shape
133
- inverse_impl.input_shape = self.output_shape
134
- return inverse_impl
135
-
136
- def shape_propagation(self, input_node_meta):
137
- self.input_shape = input_node_meta.tensor.shape
138
- return _list_to_tuple(self.output_shape)
139
-
140
- def broadcast(self, shape, node_meta: NodeBase):
141
- """
142
- Broadcast the inputs based on current shape.
143
- """
144
- # Step 1: infer split
145
- flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
146
- split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
147
- split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
148
-
149
- # broadcast shape -> split_output_shape -> flatten_split_shape
150
- if len(shape) - len(split_output_shape) > 0:
151
- for _ in range(len(shape) - len(split_output_shape)):
152
- split_output_shape = [1,] + split_output_shape
153
- flatten_split_shape = [1,] + flatten_split_shape
154
- split_input_shape = [1,] + split_input_shape
155
- broadcast_factor = []
156
- for dim, old_dim in zip(shape, split_output_shape):
157
- if not isinstance(dim, list):
158
- dim = [dim,]
159
- if not isinstance(old_dim, list):
160
- old_dim = [old_dim,]
161
- if product(tuple(dim)) == product(tuple(old_dim)):
162
- broadcast_factor += [1] * len(old_dim)
163
- elif product(tuple(old_dim)) == 1:
164
- assert len(dim) == 1
165
- broadcast_factor.append(dim[0])
166
- else:
167
- raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
168
-
169
- # flatten_split_shape -> split_input_shape
170
- factor_idx = 0
171
- broadcast_split_input_shape = []
172
- for dim in split_input_shape:
173
- if isinstance(dim, list):
174
- new_dim = []
175
- for d in dim:
176
- new_dim.append(d * broadcast_factor[factor_idx])
177
- factor_idx += 1
178
- broadcast_split_input_shape.append(new_dim)
179
- else:
180
- broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
181
- factor_idx += 1
182
- broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
183
- node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
184
- node_meta.tensor.broadcast(broadcast_split_input_shape)
185
- # Last reshape op to clean up
186
- broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
187
- node_meta.tensor.reshape(broadcast_input_shape)
188
- # Update the input shape and output shape
189
- self.input_shape = _list_to_tuple(node_meta.tensor.shape)
190
- self.output_shape = _list_to_tuple(shape)
191
-
192
- def apply_to_user(self, user_meta: NodeBase):
193
- """
194
- Propagate the reshape to user nodes
195
- """
196
- user_meta.tensor.reshape(tuple(self.input_shape))
197
- if hasattr(user_meta, "store_tensor"):
198
- if user_meta.store_tensor is not None:
199
- user_meta.store_tensor.reshape(tuple(self.input_shape))
200
-
201
- def apply_to_input(self, input_meta: NodeBase):
202
- """
203
- Propagate the reshape to input nodes
204
- """
205
- input_meta.tensor.reshape(tuple(self.output_shape))
206
- if hasattr(input_meta, "store_tensor"):
207
- if input_meta.store_tensor is not None:
208
- input_meta.store_tensor.reshape(tuple(self.output_shape))
209
-
210
- #
211
- # Helper functions
212
- #
213
-
214
- def infer_split(self, input_shape, output_shape):
215
- """
216
- Infer the flatten splitted shape that can be merged to both input_shape and output_shape
217
- """
218
- input_shape = _tuple_to_list(input_shape)
219
- output_shape = _tuple_to_list(output_shape)
220
- if len(input_shape) == 0 and len(output_shape) == 0:
221
- return []
222
- if len(input_shape) == 0:
223
- if product(tuple(output_shape)) != 1:
224
- raise ValueError("Invalid reshape size")
225
- else:
226
- return output_shape
227
- if len(output_shape) == 0:
228
- if product(tuple(input_shape)) != 1:
229
- raise ValueError("Invalid reshape size")
230
- else:
231
- return input_shape
232
- # This is done recursively by only process the last dimension at each time
233
- old_dim = input_shape[-1]
234
- new_dim = output_shape[-1]
235
- # Exact match
236
- if old_dim == new_dim:
237
- return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
238
- # Needs split
239
- if old_dim > new_dim and old_dim % new_dim == 0:
240
- residual = old_dim // new_dim
241
- return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
242
- # Needs merge
243
- if old_dim < new_dim and new_dim % old_dim == 0:
244
- residual = new_dim // old_dim
245
- return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
246
-
247
- raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
248
-
249
- def infer_merge(self, flatten_shape, shape):
250
- flatten_shape = _tuple_to_list(flatten_shape)
251
- shape = _tuple_to_list(shape)
252
- idx_flat = len(flatten_shape) - 1
253
- merged_shape = []
254
- for dim in reversed(shape):
255
- # Exact match
256
- if dim == flatten_shape[idx_flat]:
257
- merged_shape.append(dim)
258
- idx_flat -= 1
259
- # need group
260
- elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
261
- residual = dim
262
- group = []
263
- while(residual > 1):
264
- group.append(flatten_shape[idx_flat])
265
- residual = residual // flatten_shape[idx_flat]
266
- idx_flat -= 1
267
- merged_shape.append(group[::-1])
268
- else:
269
- raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
270
-
271
- return merged_shape[::-1]
272
-
273
-
274
- class LayoutNode(NodeBase):
275
- """
276
- Layout manipulation nodes
277
- """
278
- fn_to_impl = {
279
- "permute": PermutationImpl,
280
- "reshape": ReshapeImpl
281
- }
282
- def __init__(self, name: str, fn, kwargs: dict) -> None:
283
- super().__init__(name)
284
- self.op = "layout"
285
- self.fn = fn
286
- self.kwargs = kwargs
287
- self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
288
-
289
- def get_inverse_node(self):
290
- inverse_node = deepcopy(self)
291
- inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
292
- return inverse_node
293
-
294
- def shape_propagation(self, input_node_metas):
295
- if self._tensor is not None:
296
- return
297
- assert len(input_node_metas) == 1, "Layout node can only have one input node"
298
-
299
- output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
300
-
301
- self._tensor = Tensor(
302
- element=self.element_output,
303
- shape=output_shape, layout_tag=LayoutType.RowMajor
304
- )
305
-
306
- return super().shape_propagation(input_node_metas)
307
-
308
- def type_propagation(self, input_node_metas: 'list[NodeBase]'):
309
- """
310
- The store nodes has element_output = element_input
311
- """
312
- assert len(input_node_metas) == 1, "Layout node can only have one input node"
313
- self.element_output = input_node_metas[0].element_output
314
-
315
- def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
316
- """
317
- Propagate the broadcast in the reversed topological order
318
- """
319
- if self.tensor is None:
320
- raise RuntimeError(f"The tensor of node {self.name} is unknown.")
321
- shape = self.tensor.shape
322
-
323
- for child in input_node_metas:
324
- self.underlying_impl.broadcast(shape, child)
325
-
326
- def apply_to_user(self, usr_meta: NodeBase):
327
- """
328
- Propagate the permutation to user nodes
329
- """
330
- self.underlying_impl.apply_to_user(usr_meta)
331
-
332
- def apply_to_input(self, input_meta: NodeBase):
333
- """
334
- Propagate the permutation to input nodes
335
- """
336
- self.underlying_impl.apply_to_input(input_meta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py DELETED
@@ -1,294 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Load nodes and implementations
35
- """
36
-
37
- import ctypes
38
-
39
- from cutlass_cppgen.backend.c_types import tuple_factory
40
- from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
41
- from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
42
-
43
-
44
- class LoadImplBase(ImplBase):
45
- """
46
- Base class for load node implementations
47
- """
48
- reserved_names = ["accum", "C"]
49
- def __init__(self, node) -> None:
50
- super().__init__(node)
51
- self.element = node.element
52
- self.element_output = node.element_output
53
- self.stride = node.tensor.stride
54
-
55
-
56
- class AccumulatorImpl(LoadImplBase):
57
- """
58
- Accumulator node implementation
59
- """
60
-
61
- @staticmethod
62
- def match(node, problem_size: tuple):
63
- return node.name == "accum" and node.tensor.shape == problem_size
64
-
65
-
66
- class LoadSrcImpl(LoadImplBase):
67
- """
68
- Load C implementation
69
- """
70
- @property
71
- def name_camel(self) -> str:
72
- return "TensorC"
73
-
74
- @property
75
- def argument_type_c(self):
76
- stride_mnl = self.get_stride_mnl()
77
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
78
- class _Argument(ctypes.Structure):
79
- _fields_ = [
80
- ("ptr_C", ctypes.c_void_p),
81
- ("stride_C", tuple_type)
82
- ]
83
- def __init__(self, ptr) -> None:
84
- self.ptr_C = ptr
85
- self.stride_C = tuple_type(stride_mnl)
86
-
87
- return _Argument
88
-
89
- @staticmethod
90
- def match(node, problem_size: tuple):
91
- return node.name == "C" and node.tensor.shape == problem_size
92
-
93
-
94
- class AuxLoadImpl(LoadImplBase):
95
- """
96
- Load arbitrary tensor
97
- """
98
- @property
99
- def argument_type(self):
100
- stride_mnl = self.get_stride_mnl()
101
- name = self.name
102
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
103
- element_type = self.element
104
- class _Argument(ctypes.Structure):
105
- _fields_ = [
106
- ("ptr_aux", ctypes.c_void_p),
107
- ("null_default", dtype2ctype[element_type]),
108
- ("dAux", tuple_type)
109
- ]
110
- def __init__(self, kwargs) -> None:
111
- ptr = kwargs[name]
112
- self.ptr_aux = ptr
113
- self.null_default = to_ctype_value(0, element_type)
114
- self.dAux = tuple_type(stride_mnl)
115
-
116
- return _Argument
117
-
118
- @staticmethod
119
- def match(node, problem_size: tuple):
120
- if node.name in LoadImplBase.reserved_names:
121
- return False
122
- strideMN = node.tensor.stride[-2:]
123
- if (strideMN[0] == 1 and strideMN[1] != 0 or
124
- strideMN[0] != 0 and strideMN[1] == 1 ):
125
- return True
126
- else:
127
- return False
128
-
129
-
130
- class RowBroadcastImpl(LoadImplBase):
131
- """
132
- Broadcast a row vector
133
- """
134
- def __init__(self, node) -> None:
135
- super().__init__(node)
136
- self.stride_dtype = "int"
137
-
138
- @property
139
- def argument_type(self):
140
- stride_mnl = self.get_stride_mnl()
141
- name = self.name
142
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
143
- element_type = self.element
144
- class _Argument(ctypes.Structure):
145
- _fields_ = [
146
- ("ptr_row", ctypes.c_void_p),
147
- ("null_default", dtype2ctype[element_type]),
148
- ("dRow", tuple_type)
149
- ]
150
- def __init__(self, kwargs) -> None:
151
- ptr = kwargs[name]
152
- self.ptr_row = ptr
153
- self.null_default = to_ctype_value(0, element_type)
154
- self.dRow = tuple_type(stride_mnl)
155
-
156
- return _Argument
157
-
158
- @staticmethod
159
- def match(node, problem_size: tuple):
160
- if node.name in LoadImplBase.reserved_names:
161
- return False
162
-
163
- strideMN = node.tensor.stride[-2:]
164
- if strideMN == (0, 1):
165
- return True
166
- else:
167
- return False
168
-
169
-
170
- class ColumnBroadcastImpl(LoadImplBase):
171
- """
172
- Broadcast a column vector
173
- """
174
- def __init__(self, node) -> None:
175
- super().__init__(node)
176
- self.stride_dtype = "int"
177
-
178
- @property
179
- def argument_type(self):
180
- stride_mnl = self.get_stride_mnl()
181
- name = self.name
182
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
183
- element_type = self.element
184
- class _Argument(ctypes.Structure):
185
- _fields_ = [
186
- ("ptr_col", ctypes.c_void_p),
187
- ("null_default", dtype2ctype[element_type]),
188
- ("dCol", tuple_type)
189
- ]
190
- def __init__(self, kwargs) -> None:
191
- ptr = kwargs[name]
192
- self.ptr_col = int(ptr)
193
- self.null_default = to_ctype_value(0, element_type)
194
- self.dCol = tuple_type(stride_mnl)
195
-
196
- return _Argument
197
-
198
- @staticmethod
199
- def match(node, problem_size: tuple):
200
- if node.name in LoadImplBase.reserved_names:
201
- return False
202
-
203
- strideMN = node.tensor.stride[-2:]
204
- if strideMN == (1, 0):
205
- return True
206
- else:
207
- return False
208
-
209
-
210
- class ScalarBroadcastImpl(LoadImplBase):
211
- """
212
- Broadcast a scalar
213
- """
214
- def __init__(self, node) -> None:
215
- super().__init__(node)
216
- self.stride_dtype = "int"
217
-
218
- @property
219
- def argument_type(self):
220
- stride_mnl = self.get_stride_mnl()
221
- name = self.name
222
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
223
- element_type = self.element
224
-
225
- if self.tensor.is_constant:
226
- value = self.tensor.value
227
- class _Argument(ctypes.Structure):
228
- _fields_ = [
229
- ("scalars", dtype2ctype[element_type]),
230
- ("scalar_ptrs", ctypes.c_void_p),
231
- ("dScalar", tuple_type)
232
- ]
233
- def __init__(self, kwargs) -> None:
234
- self.scalars = to_ctype_value(value, element_type)
235
- self.scalar_ptrs = 0
236
- self.dScalar = tuple_type(stride_mnl)
237
-
238
- else:
239
- class _Argument(ctypes.Structure):
240
- _fields_ = [
241
- ("scalars", dtype2ctype[element_type]),
242
- ("scalar_ptrs", ctypes.c_void_p),
243
- ("dScalar", tuple_type)
244
- ]
245
- def __init__(self, kwargs) -> None:
246
- scalar_or_ptr = kwargs[name]
247
- if isinstance(scalar_or_ptr, float):
248
- self.scalars = to_ctype_value(scalar_or_ptr, element_type)
249
- self.scalar_ptrs = 0
250
- else:
251
- self.scalar_ptrs = int(scalar_or_ptr)
252
-
253
- self.dScalar = tuple_type(stride_mnl)
254
-
255
- return _Argument
256
-
257
- @staticmethod
258
- def match(node, problem_size: tuple):
259
- if node.name in LoadImplBase.reserved_names:
260
- return False
261
-
262
- strideMN = node.tensor.stride[-2:]
263
- if strideMN == (0, 0):
264
- return True
265
- else:
266
- return False
267
-
268
-
269
- class LoadNode(NodeBase):
270
- """
271
- Load Node
272
- """
273
- cnt = 0
274
- possible_impls = [
275
- AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
276
- RowBroadcastImpl, ColumnBroadcastImpl,
277
- ScalarBroadcastImpl
278
- ]
279
- def __init__(self, name: str) -> None:
280
- if name is None:
281
- name = f"load{LoadNode.cnt}"
282
- LoadNode.cnt += 1
283
- super().__init__(name)
284
- self.op = "load"
285
-
286
- def type_propagation(self, *args, **kwargs):
287
- """
288
- Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
289
- """
290
- if self.tensor is None:
291
- raise RuntimeError(f"The tensor of node {self.name} is unknown.")
292
-
293
- self.element = self.tensor.element
294
- self.element_output = self.tensor.element
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py DELETED
@@ -1,306 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Base & visitor classes of DAGIR Nodes
35
- """
36
-
37
- import ctypes
38
- from re import sub
39
-
40
- from cutlass_library import LayoutType
41
-
42
- from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
43
- from cutlass_cppgen.backend.evt.ir.tensor import Tensor
44
-
45
-
46
- class TupleEmitter:
47
- """
48
- Emit the cute tuple to C++ code
49
- """
50
- def __init__(self, stride_dtype):
51
- self.stride_dtype = stride_dtype
52
-
53
- def emit(self, py_tuple):
54
- if isinstance(py_tuple, int):
55
- if py_tuple in [0, 1]:
56
- return f"cute::Int<{py_tuple}>"
57
- else:
58
- return f"{self.stride_dtype}"
59
- elif isinstance(py_tuple, tuple):
60
- decl = "cute::Stride<"
61
- for item in py_tuple:
62
- decl += self.emit(item) + ", "
63
- return decl[:-2] + ">"
64
- else:
65
- raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
66
-
67
-
68
- class ImplBase:
69
- """
70
- Base class for Node Implementation
71
- """
72
- def __init__(self, node) -> None:
73
- self.node = node
74
- self.name = node.name
75
- self.tensor = node.tensor
76
- self._type_decl = None
77
- self.tuple_emitter = TupleEmitter("int64_t")
78
-
79
- @property
80
- def stride_dtype(self):
81
- return self.tuple_emitter.stride_dtype
82
-
83
- @stride_dtype.setter
84
- def stride_dtype(self, stride_dtype):
85
- self.tuple_emitter.stride_dtype = stride_dtype
86
-
87
- @staticmethod
88
- def match(node, problem_size: tuple):
89
- """
90
- Match function used in get_underlying_impl
91
- """
92
- raise NotImplementedError(f"The `match` function is not defined.")
93
-
94
- @property
95
- def argument_type(self):
96
- """
97
- Default class for Argument Type
98
- """
99
- class _Argument(ctypes.Structure):
100
- _fields_ = []
101
-
102
- def __init__(self, *args, **kwargs) -> None:
103
- pass
104
-
105
- return _Argument
106
-
107
- @property
108
- def name_camel(self) -> str:
109
- """
110
- Return the CamelCase name.
111
- """
112
- return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
113
-
114
- @property
115
- def stride_mnl(self):
116
- """
117
- Typename StrideMNL
118
- """
119
- stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
120
- return self.tuple_emitter.emit(stride)
121
-
122
- def get_non_constant_stride(self, py_tuple):
123
- if isinstance(py_tuple, int):
124
- if py_tuple not in [0, 1]:
125
- return py_tuple
126
- else:
127
- return None
128
- non_constant_stride = []
129
- for item in py_tuple:
130
- item_out = self.get_non_constant_stride(item)
131
- if item_out:
132
- non_constant_stride.append(item_out)
133
- return tuple(non_constant_stride)
134
-
135
- def get_stride_mnl(self):
136
- """
137
- Get the non-zero stride mnl. This is used in argument construction
138
- """
139
- stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
140
- return stride
141
-
142
- def get_smem_size(self, *args, **kwargs):
143
- """
144
- Get the shared memory size and alignment of current node
145
- """
146
- return (0, 1)
147
-
148
-
149
- class NoOpImpl(ImplBase):
150
- """
151
- The NoOpImpl does nothing but forward its input to users
152
- """
153
- def __init__(self, node) -> None:
154
- super().__init__(node)
155
-
156
- @staticmethod
157
- def match(node, problem_size: tuple):
158
- if node.op == "store":
159
- # Store that is not output is a No OP
160
- return not node.is_output
161
-
162
-
163
- class NodeBase:
164
- """
165
- Base class of DAG Node
166
- """
167
- def __init__(self, name: str) -> None:
168
- self.name = name
169
- self.underlying_impl = None
170
-
171
- self._tensor = None
172
-
173
- # Whether the node is disabled for emit
174
- self.disabled = False
175
-
176
- @property
177
- def name_camel(self) -> str:
178
- """
179
- Return the CamelCase name.
180
- """
181
- return self.underlying_impl.name_camel
182
-
183
- @property
184
- def tensor(self) -> Tensor:
185
- """
186
- Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
187
- """
188
- return self._tensor
189
-
190
- @tensor.setter
191
- def tensor(self, kwargs):
192
- """
193
- Setting the tensor
194
- """
195
- self._tensor = Tensor(**kwargs)
196
-
197
- #
198
- # Helper functions for type/shape propagation
199
- #
200
-
201
- def shape_propagation(self, input_node_metas):
202
- """
203
- Infer shape from input nodes
204
- General Broadcasting Rules from NumPy
205
- When operating on two arrays, we compare their shapes element-wise.
206
- It starts with the trailing (i.e. rightmost) dimension and works its
207
- way left. Two dimensions are compatible when
208
- 1. they are equal
209
- 2. one of them is 1
210
- """
211
- if self._tensor is not None:
212
- return
213
-
214
- shape = None
215
- for src in input_node_metas:
216
- src_shape = src.tensor.shape
217
- if shape is None:
218
- shape = src_shape
219
- else:
220
- len_difference = len(shape) - len(src_shape)
221
- if len_difference > 0:
222
- for _ in range(len_difference):
223
- src_shape = [1, ] + list(src_shape)
224
- elif len_difference < 0:
225
- for _ in range(-len_difference):
226
- shape = [1, ] + list(shape)
227
- broadcasted_shape = []
228
- # Infer broadcast shape
229
- for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
230
- if shape_dim == 1:
231
- broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
232
- elif src_dim == 1:
233
- broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
234
- elif shape_dim == src_dim:
235
- broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
236
- else:
237
- error_msg = "Dimension mismatch between "
238
- for src_ in input_node_metas:
239
- error_msg += f"{src_.name}{src_.tensor.shape}, "
240
- error_msg = error_msg[:-2] + "."
241
- raise RuntimeError(error_msg)
242
- shape = tuple(broadcasted_shape)
243
-
244
- self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
245
-
246
- def type_propagation(self, *args, **kwargs):
247
- """
248
- Each node is associated with two data types: `element` and `element_output`.
249
- The `element_output` is the type of return array of the node. The `element`
250
- has specific meaning for different node types.
251
- * Load Node: data type of tensor in gmem
252
- * Compute Node: element compute
253
- * Store Node: data type of tensor in gmem
254
- This function must be overloaded in the derived classes
255
- """
256
- raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
257
-
258
- def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
259
- """
260
- Propagate the broadcast in the reversed topological order.
261
- For example:
262
- C[l, m, n] = A[m, 1] + B[l, m, n]
263
- After the broadcast propagation, it will be come
264
- C[l, m, n] = A[l, m, n] + B[l, m, n]
265
- and each tensor will have a proper stride accessing the underlying tensor
266
- """
267
- if self.tensor is None:
268
- raise RuntimeError(f"The tensor of node {self.name} is unknown.")
269
- for child in input_node_metas:
270
- child.tensor.broadcast(self.tensor.shape)
271
-
272
- def get_underlying_impl(self, problem_size: tuple):
273
- """
274
- Get the underlying implementation of the current node.
275
- """
276
- if self.tensor is None:
277
- raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
278
-
279
- for impl in self.possible_impls:
280
- if impl.match(self, problem_size):
281
- self.underlying_impl = impl(self)
282
- break
283
-
284
- if self.underlying_impl is None:
285
- raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
286
-
287
- #
288
- # Visitor Nodes & Impls
289
- #
290
-
291
- class TopoVisitorImpl(ImplBase):
292
- """
293
- Impl for topological visitor
294
- """
295
- def __init__(self, node) -> None:
296
- super().__init__(node.output_node)
297
- self.name = node.name
298
- self.element_output = node.output_node.element_output
299
-
300
- class TopoVisitorNode(NodeBase):
301
- def __init__(self, name: str, subgraph, output_node) -> None:
302
- super().__init__(name)
303
- self.subgraph = subgraph
304
- self.output_node = output_node
305
- self.op = "dag"
306
- self.underlying_impl = TopoVisitorImpl(self)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py DELETED
@@ -1,277 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Store node and implementations
35
- """
36
-
37
- import ctypes
38
-
39
- from cutlass_library import DataType
40
-
41
- from cutlass_cppgen.backend.c_types import tuple_factory
42
- from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
43
- from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
44
- from cutlass_cppgen.backend.evt.ir.tensor import Tensor
45
- from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
46
-
47
-
48
- class StoreImplBase(ImplBase):
49
- """
50
- Base class for store node implementation
51
- """
52
- reserved_names = ["D"]
53
- def __init__(self, node) -> None:
54
- super().__init__(node)
55
- self.element = node.element
56
- self.element_output = node.element_output
57
- self.stride = node.store_tensor.stride
58
-
59
-
60
- class StoreDImpl(StoreImplBase):
61
- """
62
- Store D implementation
63
- """
64
-
65
- @property
66
- def argument_type_d(self):
67
- stride_mnl = self.get_stride_mnl()
68
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
69
- class _Argument(ctypes.Structure):
70
- _fields_ = [
71
- ("ptr_D", ctypes.c_void_p),
72
- ("stride_D", tuple_type)
73
- ]
74
- def __init__(self, ptr: int) -> None:
75
- self.ptr_D = ptr
76
- self.stride_D = tuple_type(stride_mnl)
77
-
78
- return _Argument
79
-
80
- @staticmethod
81
- def match(node, problem_size: tuple):
82
- if node.name == "D" and node.store_tensor.shape == problem_size:
83
- return True
84
- return False
85
-
86
-
87
- class AuxStoreImpl(StoreImplBase):
88
- def __init__(self, node) -> None:
89
- super().__init__(node)
90
- self.round_style = FloatRoundStyle.ToNearest
91
-
92
- @property
93
- def argument_type(self):
94
- stride_mnl = self.get_stride_mnl()
95
- name = self.name
96
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
97
- class _Argument(ctypes.Structure):
98
- _fields_ = [
99
- ("ptr_aux", ctypes.c_void_p),
100
- ("dAux", tuple_type)
101
- ]
102
- def __init__(self, kwargs) -> None:
103
- ptr = kwargs[name]
104
- self.ptr_aux = ptr
105
- self.dAux = tuple_type(stride_mnl)
106
-
107
- return _Argument
108
-
109
- @staticmethod
110
- def match(node, problem_size: tuple):
111
- if not node.is_output:
112
- return False
113
- if node.name in StoreImplBase.reserved_names:
114
- return False
115
-
116
- strideMN = node.store_tensor.stride[-2:]
117
- if (strideMN[0] == 1 and strideMN[1] != 0 or
118
- strideMN[0] != 0 and strideMN[1] == 1 ):
119
- return True
120
- else:
121
- return False
122
-
123
-
124
- class ReductionImplBase(StoreImplBase):
125
- def __init__(self, node) -> None:
126
- super().__init__(node)
127
- self.element = node.store_tensor.element
128
- self.element_compute = node.element_compute
129
- self.reg_reduce_fn = self.node.reg_reduce_fn
130
- self.gmem_reduce_fn = self.node.gmem_reduce_fn
131
- self.round_style = node.round_style
132
- self.stride_dtype = "int"
133
-
134
- def get_reduce_identity(self):
135
- """
136
- Return the reduction identity of the current reduce_fn
137
- """
138
- maxes = {
139
- DataType.f32: (2 ** 31) - 1,
140
- DataType.f16: (2 ** 15),
141
- DataType.s32: (2 ** 31) - 1,
142
- DataType.s8: (2 ** 7) - 1
143
- }
144
- mins = {
145
- DataType.f32: -maxes[DataType.f32],
146
- DataType.f16: -maxes[DataType.f16],
147
- DataType.s32: -maxes[DataType.s32],
148
- DataType.s8: -maxes[DataType.s8]
149
- }
150
- if self.reg_reduce_fn == FunctionalOp.Maximum:
151
- if self.element_compute not in mins:
152
- raise Exception(f"No min entry for data type {self.element_compute}")
153
- return to_ctype_value(mins[self.element_compute], self.element_compute)
154
- elif self.reg_reduce_fn == FunctionalOp.Multiplies:
155
- return to_ctype_value(1., self.element_compute)
156
- elif self.reg_reduce_fn == FunctionalOp.Minimum:
157
- if self.element_compute not in maxes:
158
- raise Exception(f"No max entry for data type {self.element_compute}")
159
- return to_ctype_value(maxes[self.element_compute], self.element_compute)
160
- else:
161
- return to_ctype_value(0., self.element_compute)
162
-
163
- @property
164
- def argument_type(self):
165
- self.get_reduce_identity()
166
- stride_mnl = self.get_stride_mnl()
167
- name = self.name
168
- tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
169
- element_compute = self.element_compute
170
- reduce_identity = self.get_reduce_identity()
171
- class _Argument(ctypes.Structure):
172
- _fields_ = [
173
- ("ptr", ctypes.c_void_p),
174
- ("reduce_identity", dtype2ctype[element_compute]),
175
- ("dMNL", tuple_type)
176
- ]
177
- def __init__(self, kwargs) -> None:
178
- ptr = kwargs[name]
179
- self.ptr = ptr
180
- self.reduce_identity = reduce_identity
181
- self.dMNL = tuple_type(stride_mnl)
182
-
183
- return _Argument
184
-
185
-
186
- class ColumnReductionImpl(ReductionImplBase):
187
-
188
- @staticmethod
189
- def match(node, problem_size: tuple):
190
- if not node.is_output:
191
- return False
192
- if node.name in StoreImplBase.reserved_names:
193
- return False
194
-
195
- strideMN = node.store_tensor.stride[-2:]
196
- if strideMN == (1, 0):
197
- return True
198
- else:
199
- return False
200
-
201
-
202
- class RowReductionImpl(ReductionImplBase):
203
-
204
- @staticmethod
205
- def match(node, problem_size: tuple):
206
- if not node.is_output:
207
- return False
208
- if node.name in StoreImplBase.reserved_names:
209
- return False
210
-
211
- strideMN = node.store_tensor.stride[-2:]
212
- if strideMN == (0, 1):
213
- return True
214
- else:
215
- return False
216
-
217
-
218
- class ScalarReductionImpl(ReductionImplBase):
219
-
220
- @staticmethod
221
- def match(node, problem_size: tuple):
222
- if not node.is_output:
223
- return False
224
- if node.name in StoreImplBase.reserved_names:
225
- return False
226
-
227
- strideMN = node.store_tensor.stride[-2:]
228
- if strideMN == (0, 0):
229
- return True
230
- else:
231
- return False
232
-
233
-
234
- class StoreNode(NodeBase):
235
- """
236
- Store node
237
- """
238
- possible_impls = [
239
- AuxStoreImpl, RowReductionImpl,
240
- ColumnReductionImpl, ScalarReductionImpl,
241
- NoOpImpl, StoreDImpl
242
- ]
243
- def __init__(self, name: str) -> None:
244
- super().__init__(name)
245
- self.op = "store"
246
- self.is_output = False
247
- self._store_tensor = None
248
-
249
- @property
250
- def store_tensor(self) -> Tensor:
251
- """
252
- Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
253
- """
254
- return self._store_tensor
255
-
256
- @store_tensor.setter
257
- def store_tensor(self, kwargs):
258
- """
259
- Setting the tensor
260
- """
261
- self._store_tensor = Tensor(**kwargs)
262
-
263
- def type_propagation(self, input_node_metas: 'list[NodeBase]'):
264
- """
265
- The store nodes has element_output = element_input
266
- """
267
- if self.is_output:
268
- if self.store_tensor is None:
269
- raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
270
- self.element = self.store_tensor.element
271
- assert len(input_node_metas) == 1, "Store node can only have one input node"
272
- self.element_output = input_node_metas[0].element_output
273
-
274
- def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
275
- super().broadcast_propagation(input_node_metas)
276
- if self.is_output:
277
- self._store_tensor.broadcast(self.tensor.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py DELETED
@@ -1,137 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- High-level class for tensor
35
- """
36
-
37
- from cutlass_library import LayoutType
38
-
39
- from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
40
- Layout,
41
- broadcast,
42
- canonicalization,
43
- permutation,
44
- reshape,
45
- _reverse_tuple
46
- )
47
- from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
48
-
49
-
50
- class Tensor:
51
- """
52
- The tensor abstracts the data type
53
- """
54
- def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
55
- if element is not None and tensor is not None:
56
- raise Exception(f"Must not specify both element and tensor")
57
- elif shape is not None and tensor is not None:
58
- raise Exception(f"Must not specify both shape and tensor")
59
- elif layout_tag is not None and tensor is not None:
60
- raise Exception(f"Must not specify both layout_tag and tensor")
61
- elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
62
- raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
63
- elif stride is not None and tensor is not None:
64
- raise Exception(f"Must not specify both stride and tensor")
65
- elif stride is not None and layout_tag is not None:
66
- raise Exception(f"Must not specify layout_tag when stride is provided")
67
-
68
- if isinstance(tensor, Tensor):
69
- # Directly copy all the attributes
70
- self.__dict__.update(vars(tensor))
71
- else:
72
- if tensor is None:
73
- self.element = library_type(element)
74
- else:
75
- self.element, layout_tag = get_datatype_and_layout(tensor)
76
- shape = get_tensor_shape(tensor)
77
- if stride is not None:
78
- self.layout = Layout(shape[::-1], stride[::-1])
79
- else:
80
- if layout_tag == LayoutType.RowMajor:
81
- self.layout = Layout(shape[::-1])
82
- elif layout_tag == LayoutType.ColumnMajor:
83
- self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
84
- self.layout = canonicalization(self.layout)
85
-
86
- self.is_constant = is_constant
87
- # Save the tensor value if it is constant
88
- if is_constant and tensor is not None:
89
- self.value = tensor
90
-
91
- @property
92
- def shape(self):
93
- """
94
- Returns the RowMajor layout shape
95
- """
96
- return _reverse_tuple(self.layout.shape)
97
-
98
- @property
99
- def stride(self):
100
- """
101
- Returns the RowMajor layout stride
102
- """
103
- return _reverse_tuple(self.layout.stride)
104
-
105
- @property
106
- def rank(self):
107
- """
108
- Returns the rank of the tensor
109
- """
110
- return len(self.shape)
111
-
112
- #
113
- # Layout Algorithms
114
- #
115
-
116
- def broadcast(self, shape):
117
- """
118
- Broadcast self.layout to shape
119
- """
120
- assert isinstance(shape, tuple)
121
- self.layout = broadcast(self.layout, _reverse_tuple(shape))
122
-
123
- def reshape(self, shape):
124
- """
125
- Reshape self.layout to shape
126
- """
127
- assert isinstance(shape, tuple)
128
- reverse_shape = _reverse_tuple(shape)
129
- self.layout = reshape(self.layout, reverse_shape)
130
-
131
- def permute(self, indices):
132
- """
133
- Permute self.layout according to indices
134
- """
135
- length = len(indices)
136
- indices = [length - idx - 1 for idx in indices]
137
- self.layout = permutation(self.layout, indices[::-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py DELETED
@@ -1,42 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
34
- from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
35
- from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
36
- from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
37
- from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
38
- from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
39
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
40
- from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
41
- from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
42
- from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py DELETED
@@ -1,143 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
- from __future__ import annotations
33
-
34
- import subprocess
35
-
36
- from cutlass_library import DataTypeTag
37
-
38
- from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
39
-
40
-
41
- _COLOR_MAP = {
42
- "load": '"AliceBlue"',
43
- "compute": "LemonChiffon1",
44
- "accumulator": "LightGrey",
45
- "store": "PowderBlue",
46
- "layout": "lightseagreen",
47
- "dag": "darkorange"
48
- }
49
-
50
-
51
- class EVTGraphDrawer:
52
- """
53
- Visualize a EVT DAGIR with graphviz
54
- """
55
- def __init__(
56
- self,
57
- graph: DAGIR,
58
- name: str
59
- ):
60
- self._name = name
61
- self._dot_graphs = {}
62
-
63
- self._dot_graphs[name] = self._to_dot(graph, name)
64
-
65
- def _get_node_style(self, node):
66
- template = {
67
- "shape": "record",
68
- "fillcolor": "#CAFFE3",
69
- "style": '"filled,rounded"',
70
- "fontcolor": "#000000",
71
- }
72
- if node.op in _COLOR_MAP:
73
- template["fillcolor"] = _COLOR_MAP[node.op]
74
- else:
75
- raise NotImplementedError("unknown node op")
76
- if node.disabled:
77
- template["fontcolor"] = "grey"
78
- template["fillcolor"] = "white"
79
- return template
80
-
81
- def _get_node_label(self, node):
82
- label = "{" + f"name={node.name}|op={node.op}"
83
- if node.op == "layout":
84
- label += f"|fn={node.fn.__name__}"
85
- for key in node.kwargs:
86
- label += f"|{key}={node.kwargs[key]}"
87
- if node.underlying_impl is not None:
88
- label += f"|impl={type(node.underlying_impl).__name__}"
89
- if node.op == "load":
90
- label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
91
- elif node.op == "compute":
92
- label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
93
- elif node.op == "store":
94
- label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
95
- elif node.op == "dag":
96
- label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
97
- if node.tensor is not None:
98
- shape = node.tensor.shape
99
- stride = node.tensor.stride
100
- label += f"|shape={shape}|stride={stride}"
101
-
102
- if hasattr(node, "store_tensor"):
103
- if node.store_tensor is not None:
104
- store_shape = node.store_tensor.shape
105
- store_stride = node.store_tensor.stride
106
- label += f"|store_shape={store_shape}|stride_stride={store_stride}"
107
-
108
- label += "}"
109
- return label
110
-
111
- def _to_dot(
112
- self,
113
- graph: DAGIR,
114
- name: str
115
- ):
116
- import pydot
117
- dot_graph = pydot.Dot(name, randir="TB")
118
- for node in graph.nodes_meta:
119
- style = self._get_node_style(node)
120
- label = self._get_node_label(node)
121
- dot_node = pydot.Node(
122
- node.name, label=label, **style
123
- )
124
- dot_graph.add_node(dot_node)
125
- if node.op == "dag":
126
- dot_subgraph = self._to_dot(node.subgraph, name=node.name)
127
- self._dot_graphs[node.name] = dot_subgraph
128
-
129
- # Add edges
130
- for src, dst in graph.edges:
131
- weight = graph.get_edge_weight(src, dst)
132
- dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
133
-
134
- return dot_graph
135
-
136
- def get_dot_graph(self) -> pydot.Dot:
137
- return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]
138
-
139
- def get_dot_graph_by_name(self, name) -> pydot.Dot:
140
- return self._dot_graphs[name]
141
-
142
- def get_main_dot_graph(self) -> pydot.Dot:
143
- return self._dot_graphs[self._name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py DELETED
@@ -1,120 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Construct the epilogue visitor argument type
35
- """
36
-
37
- from cutlass_cppgen.backend.c_types import visitor_factory
38
- from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
39
- from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
40
- from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
41
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
42
- from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
43
- from cutlass_cppgen.backend.evt.passes.util import cc_map
44
-
45
-
46
- class PassGetArgumentType(EVTPassBase):
47
- """
48
- Construct the epilogue visitor argument type
49
- """
50
- dependencies = [
51
- PassShapeTypePropagation, # The Layout of all nodes must be set
52
- PassDAG2Tree, # The type of each node must be set
53
- PassGetImpl # The DAG subgraphs must be set
54
- ]
55
-
56
- def requires(self) -> None:
57
- # Check "D" is in the node list
58
- if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
59
- raise SyntaxError(
60
- "Sm90+ EVT requires the epilogue to have a returned tensor D, "
61
- "but the variable 'D' is not found in the return values.")
62
-
63
- def call(self):
64
- nodes = self.dag_ir.nodes_topological_order()
65
- self.argument_types = {}
66
- for node in nodes:
67
- meta = self.dag_ir.get_node_meta(node)
68
- if not meta.disabled:
69
- self.argument_types[node] = meta.underlying_impl.argument_type
70
- if node == "D" and cc_map[self.cc] in [90, 100]:
71
- continue
72
- if isinstance(meta, TopoVisitorNode):
73
- self.get_dag_argument_type(node)
74
- else:
75
- self.get_evt_argument_type(node)
76
-
77
- self.cc_specific_method(self.set_argument_type)()
78
-
79
- def get_evt_argument_type(self, node):
80
- # Sort the input nodes by edge weight
81
- input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
82
- if len(input_types) > 0:
83
- self.argument_types[node] = visitor_factory(
84
- input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])
85
-
86
- def get_dag_argument_type(self, node):
87
- meta = self.dag_ir.get_node_meta(node)
88
- subgraph = meta.subgraph
89
- subgraph_nodes = subgraph.nodes_topological_order()
90
- # Visit the unvisited nodes in subgraph
91
- for n in subgraph_nodes:
92
- m = subgraph.get_node_meta(n)
93
- if m.disabled:
94
- continue
95
- else:
96
- self.argument_types[n] = m.underlying_impl.argument_type
97
- input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
98
- if len(input_types) > 0:
99
- self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])
100
-
101
- def set_argument_type(self):
102
- pass
103
-
104
- def sm90_set_argument_type(self):
105
- self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
106
- # Get the tensorD argument type
107
- self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
108
-
109
- # Get the tensorC argument type
110
- if self.dag_ir.has_node("C"):
111
- self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
112
- else:
113
- self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
114
-
115
- def sm100_set_argument_type(self):
116
- self.sm90_set_argument_type()
117
-
118
- def sm80_set_argument_type(self):
119
- nodes = self.dag_ir.nodes_topological_order()
120
- self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py DELETED
@@ -1,169 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
35
- by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
36
- """
37
-
38
- from copy import deepcopy
39
-
40
- from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
41
- from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
42
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
43
- from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
44
-
45
-
46
- class PassDAG2Tree(EVTPassBase):
47
- """
48
- Convert the DAG IR to Tree by fusing subgraphs
49
- """
50
- dependencies = [
51
- PassShapeTypePropagation,
52
- PassGetImpl
53
- ]
54
-
55
- def call(self):
56
- # Step 1: find the nodes that have multiple parents
57
- multi_parent_nodes = []
58
-
59
- for node in self.dag_ir.nodes_topological_order():
60
- if self.dag_ir.out_degree(node) > 1:
61
- multi_parent_nodes.append(node)
62
- # Step 2: find the lowest common ancestor (LCA) of all its parents
63
- for node in multi_parent_nodes:
64
- # A multi-parent node could be already fused by the previous node
65
- if not self.dag_ir.has_node(node):
66
- continue
67
- # A node uncovered by the previous fusions can have out degree change
68
- # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
69
- # Case 2: it has more than one edges to the previously fused subgraph, degree drops
70
- if self.dag_ir.out_degree(node) <= 1:
71
- continue
72
-
73
- # Otherwise, the node still
74
- reachable_nodes = []
75
- # Complexity: O(Dout*N)
76
- for parent in self.dag_ir.get_users(node):
77
- reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
78
- # get the common reachable objects
79
- common_items = set.intersection(*reachable_nodes)
80
- node_to_fuse = set.union(*reachable_nodes).difference(common_items)
81
-
82
- lca = None
83
- # If common ancestor exists, find the lowest one
84
- if len(common_items) > 0:
85
- topo_order = self.dag_ir.nodes_topological_order()
86
- topo_idx = -1
87
- for item in common_items:
88
- if lca is None:
89
- lca = item
90
- topo_idx = topo_order.index(item)
91
- else:
92
- if topo_idx > topo_order.index(item):
93
- lca = item
94
- topo_idx = topo_order.index(item)
95
- else:
96
- # there is no common ancestor for all the parents, we pack all the reachable
97
- # nodes into a single DAG node as a fallback. The lca should be the input node of
98
- # one of the output nodes with out_degree = 0
99
- potential_output_nodes = []
100
- for node in node_to_fuse:
101
- if self.dag_ir.out_degree(node) == 0:
102
- potential_output_nodes.append(node)
103
- if len(potential_output_nodes) == 0:
104
- raise RuntimeError(f"No output node with out degree = 0 found.")
105
-
106
- output_node = None
107
- if (self.dag_ir.cc >= 90):
108
- # For SM90+, the lca should be the input node of D
109
- if (not self.dag_ir.has_node("D")):
110
- raise RuntimeError(f"D is not a node in the DAG IR.")
111
- output_node = "D"
112
- else:
113
- output_node = potential_output_nodes[0]
114
-
115
- if (output_node is None):
116
- raise RuntimeError(f"No output node found.")
117
- lca = self.dag_ir.get_all_inputs(output_node)[0]
118
- node_to_fuse.remove(output_node)
119
-
120
- # The lca is the output node of the DAG node
121
- # Get the nodes to be fused
122
- node_to_fuse.add(lca)
123
- # Get all the input nodes
124
- all_input_nodes = []
125
- all_output_nodes = []
126
- for node in node_to_fuse:
127
- all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
128
- all_output_nodes.append(set(self.dag_ir.get_users(node)))
129
- all_input_nodes = set.union(*all_input_nodes)
130
- all_output_nodes = set.union(*all_output_nodes)
131
-
132
- new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
133
-
134
- # Create the subgraph
135
- subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
136
- subgraph = DAGIR(self.dag_ir.cc)
137
- for node in subgraph_.nodes:
138
- meta = deepcopy(self.dag_ir.get_node_meta(node))
139
- if node not in node_to_fuse:
140
- meta.disabled = True
141
- subgraph.add_node(meta)
142
- for edge in subgraph_.edges:
143
- subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
144
-
145
-
146
- # Create the fused node
147
- dag_node = TopoVisitorNode(
148
- name=f"dag_{lca}", subgraph=subgraph,
149
- output_node=self.dag_ir.get_node_meta(lca))
150
- self.dag_ir.add_node(dag_node)
151
-
152
- # Add input edges
153
- for idx, node in enumerate(all_input_nodes):
154
- self.dag_ir.add_edge(node, dag_node.name, weight=idx)
155
-
156
- # Replace all uses with DAG node (only 1 output node)
157
- self.dag_ir.replace_all_uses_with(lca, dag_node.name)
158
-
159
- # Remove all fused nodes
160
- node_to_fuse.remove(lca)
161
- for node in node_to_fuse:
162
- self.dag_ir.remove_node(node)
163
-
164
- def ensures(self) -> None:
165
- # Ensure that after the pass, the resulting DAG becomes a tree
166
- for node in self.dag_ir.nodes:
167
- out_degree = self.dag_ir.out_degree(node)
168
- if out_degree > 1:
169
- raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py DELETED
@@ -1,64 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Fix the element_output of producer of D.
35
-
36
- In Sm90 epilogue visitor, the node writing D to gmem does not have internal
37
- element converter, so the compute node producing D must have element_output = type(D).
38
- """
39
-
40
- from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
41
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
42
-
43
-
44
- class PassFixElementD(EVTPassBase):
45
- """
46
- In Sm90 epilogue visitor, the node writing D to gmem does not have internal
47
- element converter, so the compute node producing D must have
48
- element_output = type(D)
49
- """
50
- dependencies = [
51
- PassLayoutManipulateElimination
52
- ]
53
- def get_producer(self, node, element_D):
54
- node_meta = self.dag_ir.get_node_meta(node)
55
- if node_meta.op == "compute":
56
- node_meta.element_output = element_D
57
- elif node_meta.op == "store":
58
- self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)
59
-
60
- def call(self):
61
- if self.dag_ir.has_node("D"):
62
- node_d_meta = self.dag_ir.get_node_meta("D")
63
- element_D = node_d_meta.store_tensor.element
64
- self.get_producer("D", element_D)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py DELETED
@@ -1,90 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Infer the underlying implement of each node.
35
-
36
- While the frontend only distinguish between Load/Store/Compute Node,
37
- each of these nodes can have different underlying implementation based
38
- on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
39
- This pass infers the underlying impl of each node
40
- """
41
-
42
- import cutlass_cppgen.backend.evt.backend as evt_backend
43
- from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
44
- from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
45
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
46
- from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
47
- from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
48
- from cutlass_cppgen.backend.evt.passes.util import cc_map
49
-
50
-
51
- class PassGetImpl(EVTPassBase):
52
- """
53
- While the frontend only distinguish between Load/Store/Compute Node,
54
- each of these nodes can have different underlying implementation based
55
- on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
56
- This pass infers the underlying impl of each node
57
- """
58
- dependencies = [
59
- PassShapeTypePropagation, # The shape and type info are required for inference
60
- PassFixElementD
61
- ]
62
-
63
- def __init__(self, dag_ir: DAGIR) -> None:
64
- super().__init__(dag_ir)
65
- self.no_op_elimination = PassNoOpElimination(dag_ir)
66
-
67
- def requires(self) -> None:
68
- # Verify "accum" is in the arg list
69
- if not self.dag_ir.has_node("accum"):
70
- raise SyntaxError("Cannot find 'accum' in the argument list.")
71
-
72
- def call(self):
73
- # The loop structure of the epilogue is determined by the
74
- # accumulator shape
75
- accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
76
- problem_size = accumulator.tensor.shape
77
-
78
- for node_meta in self.dag_ir.node_metas_topological_order():
79
- node_meta.get_underlying_impl(problem_size)
80
-
81
- def ensures(self) -> None:
82
- # Some nodes will be lowered to NoOp, eliminate them
83
- self.no_op_elimination()
84
- # Lower to cc-specific impl
85
- for node_meta in self.dag_ir.nodes_meta:
86
- node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
87
- node_meta.underlying_impl = getattr(
88
- node_impl_ccs,
89
- f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
90
- )(node_meta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py DELETED
@@ -1,217 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Eliminate layout manipulation nodes
35
- """
36
-
37
- from copy import deepcopy
38
-
39
- from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
40
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
41
- from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
42
-
43
-
44
- class PassLayoutManipulateElimination(EVTPassBase):
45
- """
46
- Eliminate layout manipulation nodes
47
- """
48
- dependencies = [PassShapeTypePropagation]
49
-
50
- def __init__(self, dag_ir: DAGIR) -> None:
51
- super().__init__(dag_ir)
52
- self.copy_cnt = 0
53
-
54
- def call(self):
55
- self.layout_nodes_worklist = self.get_all_layout_nodes()
56
- # Run while loop utill all layout nodes are eliminated
57
- while(len(self.layout_nodes_worklist) > 0):
58
- node = self.layout_nodes_worklist.pop(0)
59
- # for node in layout_nodes:
60
- # Step 1: get the propagation direction
61
- direction = self.get_propagation_direction(node)
62
- self.visited = []
63
- getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
64
- # Eliminate the current node
65
- input_node = self.dag_ir.get_all_inputs(node)[0]
66
- self.dag_ir.replace_all_uses_with(node, input_node)
67
- # layout_nodes = self.get_all_layout_nodes()
68
-
69
- def get_all_layout_nodes(self):
70
- layout_nodes = []
71
- for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
72
- if isinstance(node_meta, LayoutNode):
73
- layout_nodes.append(node_meta.name)
74
- return layout_nodes
75
-
76
- def get_propagation_direction(self, node: str):
77
- """
78
- The logic is propagating all layout nodes away from the accumulator node.
79
- """
80
- self.visited = []
81
- self.get_influenced_users(node)
82
- nodes_influenced_dir_users = self.visited
83
- self.visited = []
84
- self.get_influenced_inputs(node)
85
- nodes_influenced_dir_inputs = self.visited
86
-
87
- if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
88
- return "inputs"
89
- elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
90
- return "users"
91
- else:
92
- raise RuntimeError("Unsolved propagation direction")
93
-
94
- # Get all influenced nodes if we propagate along the user direction
95
- def get_influenced_users(self, node: str):
96
- if node in self.visited:
97
- return
98
- self.visited.append(node)
99
-
100
- users = self.dag_ir.get_users(node)
101
- for user in users:
102
- self.get_influenced_users(user)
103
- user_inputs = []
104
- for user in users:
105
- user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
106
- if len(user_inputs) > 0:
107
- user_inputs = set.union(*user_inputs)
108
- user_inputs.remove(node)
109
- for input in user_inputs:
110
- self.get_influenced_inputs(input)
111
-
112
- # Get all influenced nodes if we propagate along the input direction
113
- def get_influenced_inputs(self, node: str):
114
- if node in self.visited:
115
- return
116
- self.visited.append(node)
117
-
118
- inputs = self.dag_ir.get_all_inputs(node)
119
- for input in inputs:
120
- self.get_influenced_inputs(input)
121
- input_users = []
122
- for input in inputs:
123
- input_users.append(set(self.dag_ir.get_users(input)))
124
- if len(input_users) > 0:
125
- input_users = set.union(*input_users)
126
- input_users.remove(node)
127
- for user in input_users:
128
- self.get_influenced_users(user)
129
-
130
- def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
131
- copied_node_meta = deepcopy(layout_node_meta)
132
- copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
133
- self.copy_cnt += 1
134
- copied_node_meta.name = copied_node
135
- self.dag_ir.add_node(copied_node_meta)
136
- # Add edges
137
- target_inputs = self.dag_ir.get_all_inputs(target)
138
- for src in target_inputs:
139
- self.dag_ir.remove_edge(src, target)
140
- self.dag_ir.add_edge(src, copied_node)
141
- self.dag_ir.add_edge(copied_node, target)
142
- self.layout_nodes_worklist.append(copied_node)
143
-
144
- def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
145
- copied_node_meta = deepcopy(layout_node_meta)
146
- copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
147
- self.copy_cnt += 1
148
- copied_node_meta.name = copied_node
149
- self.dag_ir.add_node(copied_node_meta)
150
- # Add edges
151
- users = self.dag_ir.get_users(target)
152
- for user in users:
153
- self.dag_ir.remove_edge(target, user)
154
- self.dag_ir.add_edge(copied_node, user)
155
- self.dag_ir.add_edge(target, copied_node)
156
- self.layout_nodes_worklist.append(copied_node)
157
-
158
- # Propagate the layout `node` along the user direction
159
- def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
160
- """
161
- Propagate layout node to users
162
- """
163
- if node in self.visited:
164
- # Avoid applying twice
165
- return
166
- self.visited.append(node)
167
-
168
- node_meta = self.dag_ir.get_node_meta(node)
169
- if layout_node_meta.name != node:
170
- if isinstance(node_meta, LayoutNode):
171
- # Layout node is not transparent with layout node
172
- self.add_copy_before(layout_node_meta, node)
173
- return
174
- else:
175
- layout_node_meta.apply_to_user(node_meta)
176
-
177
- users = self.dag_ir.get_users(node)
178
- user_inputs = []
179
- for user in users:
180
- user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
181
- for user in users:
182
- self.propagate_to_users(layout_node_meta, user)
183
- if len(user_inputs) > 0:
184
- user_inputs = set.union(*user_inputs)
185
- user_inputs.remove(node)
186
- for input in user_inputs:
187
- self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
188
-
189
- # Propagate the layout `node` along the input direction
190
- def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
191
- """
192
- Propagate layout node to inputs
193
- """
194
- if node in self.visited:
195
- # Avoid applying twice
196
- return
197
- self.visited.append(node)
198
-
199
- node_meta = self.dag_ir.get_node_meta(node)
200
- if layout_node_meta.name != node:
201
- if isinstance(node_meta, LayoutNode):
202
- # Layout node is not transparent with layout node
203
- self.add_copy_after(layout_node_meta, node)
204
- return
205
- else:
206
- layout_node_meta.apply_to_input(node_meta)
207
- inputs = self.dag_ir.get_all_inputs(node)
208
- input_users = []
209
- for input in inputs:
210
- input_users.append(set(self.dag_ir.get_users(input)))
211
- for input in inputs:
212
- self.propagate_to_inputs(layout_node_meta, input)
213
- if len(input_users) > 0:
214
- input_users = set.union(*input_users)
215
- input_users.remove(node)
216
- for user in input_users:
217
- self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py DELETED
@@ -1,164 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Pass manager for DAG IR.
35
- """
36
-
37
- from typing import Any
38
-
39
- import networkx as nx
40
-
41
- from cutlass_cppgen.backend.evt.ir import DAGIR
42
- from cutlass_cppgen.backend.evt.passes.util import cc_map
43
-
44
-
45
- class EVTPassBase:
46
- """
47
- Base class for EVT Passes
48
- """
49
- dependencies = []
50
- def __init__(self, dag_ir: DAGIR) -> None:
51
- self.dag_ir = dag_ir
52
- self.cc = self.dag_ir.cc
53
-
54
- def requires(self) -> None:
55
- """
56
- This function will be called before the pass is run.
57
- """
58
- pass
59
-
60
- def call(self) -> None:
61
- """
62
- The pass that is run through the self.dag_ir
63
- """
64
- raise NotImplementedError(
65
- f"__call__ is not overwritten in Pass {self.__class__.__name__}")
66
-
67
- def ensures(self) -> None:
68
- """
69
- This function will be called after the pass is run.
70
- """
71
- pass
72
-
73
- def __call__(self) -> Any:
74
- self.requires()
75
- self.call()
76
- self.ensures()
77
-
78
- def cc_specific_method(self, func):
79
- """
80
- This enables defining function that behaves differently under different cc
81
- The simplest example of using this function is the following
82
-
83
- .. highlight:: python
84
- .. code-block:: python
85
-
86
- class ExamplePass(EVTPassBase):
87
-
88
- def call(sekf):
89
- # This automatically select the smXX_func based on current cc
90
- self.cc_specific_method(self.func)()
91
-
92
- # Interface func, can be empty
93
- def func(self):
94
- pass
95
-
96
- # Sm90 specific func
97
- def sm90_func(self):
98
- // sm90 specific method
99
- return
100
-
101
- # Sm80 specific func
102
- def sm80_func(self):
103
- // sm80 specific method
104
- return
105
- """
106
- func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
107
- if hasattr(self, func_name):
108
- return getattr(self, func_name)
109
- else:
110
- raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
111
-
112
-
113
- class EVTPassManager(nx.DiGraph):
114
- """
115
- Topological-based Pass Manager.
116
- Each registered pass has a list of dependencies. The pass manager organizes
117
- the passes as a DAG and launch the compiler passes under topological order.
118
- """
119
- def __init__(self, dag_ir: DAGIR, pass_list):
120
- super().__init__()
121
- self.dag_ir = dag_ir
122
- for pass_cls in pass_list:
123
- self.add_pass(pass_cls)
124
-
125
- self.sorted_passes = self.schedule()
126
-
127
- def get_callable(self, pass_name):
128
- """
129
- Return the callable of the pass
130
- """
131
- return self.nodes[pass_name]["callable"]
132
-
133
- def add_pass(self, pass_cls):
134
- """
135
- Add a pass to the pass manager
136
- :param pass_cls: the class of pass
137
- :type pass_cls: derived class of EVTPassBase
138
- """
139
- name = pass_cls.__name__
140
- pass_callable = pass_cls(self.dag_ir)
141
- self.add_node(name, callable=pass_callable)
142
-
143
- def schedule(self):
144
- """
145
- Schedule the added passes under topological order
146
- """
147
- # Add edges
148
- for pass_name in self.nodes:
149
- callable = self.get_callable(pass_name)
150
- for dependency_cls in callable.dependencies:
151
- self.add_edge(
152
- dependency_cls.__name__,
153
- type(callable).__name__)
154
-
155
- # Topological sort
156
- return list(nx.topological_sort(self))
157
-
158
- def __call__(self) -> Any:
159
- """
160
- Launch the registered passes
161
- """
162
- for pass_name in self.sorted_passes:
163
- callable = self.get_callable(pass_name)
164
- callable()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py DELETED
@@ -1,53 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- No op elimination node
35
- """
36
-
37
- from typing import Any
38
-
39
- from cutlass_cppgen.backend.evt.ir import NoOpImpl
40
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
41
-
42
-
43
- class PassNoOpElimination(EVTPassBase):
44
- """
45
- The dead node elimination pass removes nodes with NoOpImpl in DAG IR
46
- """
47
- dependencies = []
48
-
49
- def call(self) -> Any:
50
- for node in self.dag_ir.nodes_topological_order():
51
- node_meta = self.dag_ir.get_node_meta(node)
52
- if isinstance(node_meta.underlying_impl, NoOpImpl):
53
- self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py DELETED
@@ -1,97 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Preprocess the reduction nodes.
35
-
36
- The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
37
- This pass fuses these into a single store node, and then replaces all uses of the
38
- current node with the new store node.
39
- """
40
-
41
- from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
42
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
43
-
44
-
45
- class PassPreprocessRed(EVTPassBase):
46
- """
47
- Preprocess red nodes
48
- """
49
-
50
- def call(self):
51
- # Step 1: find the compute nodes with op=red
52
- red_compute_nodes = []
53
- for node_meta in self.dag_ir.nodes_meta:
54
- if isinstance(node_meta, ComputeNode):
55
- if type(node_meta.fn) == tuple:
56
- # To keep the frontend simple, the reduction nodes
57
- # are parsed into compute nodes by default
58
- # The simple heuristic to distinguish between compute
59
- # and reduction node is that compute node is a single function,
60
- # while the reduction node is a tuple of functions for
61
- # in-register reduction and atomic global memory reduction
62
- red_compute_nodes.append(node_meta.name)
63
-
64
- # Step 2: for each compute, merge it with the succeeding store
65
- for node in red_compute_nodes:
66
- # Verify
67
- users = self.dag_ir.get_users(node)
68
- inputs = self.dag_ir.get_all_inputs(node)
69
- # Has a single user
70
- assert len(users) == 1
71
- assert len(inputs) == 1
72
- user = users[0]
73
- input = inputs[0]
74
-
75
- user_meta = self.dag_ir.get_node_meta(user)
76
- # Must be a store node
77
- assert isinstance(user_meta, StoreNode)
78
- # With output degree == 0
79
- assert self.dag_ir.out_degree(user) == 0
80
- # Register the reduce op
81
- node_meta = self.dag_ir.get_node_meta(node)
82
- user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
83
- user_meta.element_compute = node_meta.element_compute
84
- user_meta.round_style = node_meta.round_style
85
-
86
- # Replace all uses
87
- self.dag_ir.remove_edge(input, node)
88
- input_users = self.dag_ir.get_users(input)
89
- for iu in input_users:
90
- weight = self.dag_ir.get_edge_weight(input, iu)
91
- self.dag_ir.add_edge(user, iu, weight)
92
- self.dag_ir.remove_edge(input, iu)
93
- self.dag_ir.add_edge(input, user)
94
- self.dag_ir.remove_node(node)
95
-
96
- # Register the reduction name
97
- self.dag_ir.reduction_names.append(user)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py DELETED
@@ -1,59 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Shape and type propagation pass
35
- """
36
-
37
- from cutlass_cppgen.backend.evt.ir.node import NodeBase
38
- from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
39
- from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
40
-
41
-
42
- class PassShapeTypePropagation(EVTPassBase):
43
- """
44
- Propagate the shape and type of all nodes
45
- """
46
- dependencies = [PassPreprocessRed]
47
-
48
- def call(self):
49
- # Propagate the node shape and type
50
- for node in self.dag_ir.nodes_topological_order():
51
- node_meta: NodeBase = self.dag_ir.get_node_meta(node)
52
- input_node_metas = self.dag_ir.get_all_inputs_meta(node)
53
- node_meta.type_propagation(input_node_metas)
54
- node_meta.shape_propagation(input_node_metas)
55
-
56
- for node in reversed(self.dag_ir.nodes_topological_order()):
57
- node_meta: NodeBase = self.dag_ir.get_node_meta(node)
58
- input_node_metas = self.dag_ir.get_all_inputs_meta(node)
59
- node_meta.broadcast_propagation(input_node_metas)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py DELETED
@@ -1,319 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Compute the shared memory size in bytes
35
- """
36
-
37
- from math import gcd
38
-
39
- import cutlass_library
40
- from pycute import flatten, shape_div, product
41
-
42
- import cutlass_cppgen
43
- from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
44
- from cutlass_cppgen.backend.library import DataType, DataTypeSize
45
-
46
-
47
- class GetSmemSize:
48
- """
49
- Get the size in byte of shared memory used by the kernel
50
- """
51
- def __init__(self, dag_ir: DAGIR) -> None:
52
- self.dag_ir = dag_ir
53
- self.cc = self.dag_ir.cc
54
-
55
- #
56
- # Sm90 epilogue specific
57
- #
58
-
59
- def sm90_epilogue_tile(self, tile_description):
60
- # Get the epilogue tile size
61
- schedule = tile_description.epilogue_schedule
62
- if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
63
- element_d = self.dag_ir.get_node_meta("D").element
64
- nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
65
- epi_tile_m = min(64, tile_description.threadblock_shape[0])
66
- epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
67
- epilogue_tile_mn = (epi_tile_m, epi_tile_n)
68
- elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
69
- epi_tile_m = min(128, tile_description.threadblock_shape[0])
70
- epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
71
- epilogue_tile_mn = (epi_tile_m, epi_tile_n)
72
- else:
73
- raise NotImplementedError(f"Unsupported schedule: {schedule}")
74
-
75
- # Get the pipeline stages
76
- stages_d = 2
77
- epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
78
- if self.dag_ir.has_node("C"):
79
- element_c = self.dag_ir.get_node_meta("C").element
80
- else:
81
- element_c = None
82
-
83
- element_d = self.dag_ir.get_node_meta("D").element
84
- if element_c == element_d:
85
- reuse_smem_c = True
86
- else:
87
- reuse_smem_c = False
88
- stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
89
-
90
- # Record the epilogue tile
91
- self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
92
- self.epilogue_tile_mn = epilogue_tile_mn
93
- self.epi_tiles = epi_tiles
94
- self.stages_c = stages_c
95
- self.stages_d = stages_d
96
- self.reuse_smem_c = reuse_smem_c
97
- self.element_c = element_c
98
- self.element_d = element_d
99
- self.is_source_supported = element_c is not None
100
-
101
- def sm90_or_sm100_epilogue_smem_size(self, tile_description):
102
- # Get the Fusion Storage
103
- nodes = self.dag_ir.nodes_topological_order()
104
- self.smem_types = {}
105
- for node in nodes:
106
- meta = self.dag_ir.get_node_meta(node)
107
- if not meta.disabled:
108
- self.smem_types[node] = meta.underlying_impl.get_smem_size(
109
- self.cta_tile_mnk, self.epilogue_tile_mn,
110
- self.stages_c, self.stages_d, self.epi_tiles)
111
- if node == "D":
112
- continue
113
- if isinstance(meta, TopoVisitorNode):
114
- self.get_dag_smem_type(node)
115
- else:
116
- self.get_evt_smem_type(node)
117
-
118
- thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
119
- # Get the Tensor Storage
120
- tensors = []
121
- if self.is_source_supported:
122
- smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
123
- tensors.append((smem_C, 128))
124
- else:
125
- tensors.append((0, 1))
126
- if self.reuse_smem_c:
127
- tensors.append((0, 128))
128
- else:
129
- smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
130
- tensors.append((smem_D, 128))
131
- tensors.append((thread_smem_size, 128))
132
-
133
- tensor_smem_size = self.get_struct_size(tensors)
134
- # Get pipeline storage size
135
- # sizeof(uint64_t * stages_c * 2), alignment of uint64_t
136
- # 2 is for FullBarrier and EmptyBarrier
137
- pipeline_smem_size = (8 * self.stages_c * 2, 8)
138
-
139
- # get SharedStorage size
140
- smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
141
- return smem_size[0]
142
-
143
- def sm90_epilogue_smem_size(self, tile_description):
144
- """
145
- Compute the shared memory size of sm90 collective epilogue
146
- """
147
- self.sm90_epilogue_tile(tile_description)
148
- return self.sm90_or_sm100_epilogue_smem_size(tile_description)
149
-
150
- #
151
- # Sm100 epilogue specific
152
- #
153
-
154
- def sm100_epilogue_tile(self, tile_description):
155
- cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
156
- mma_tile = cta_tile
157
-
158
- if tile_description.is_2sm:
159
- cta_tile = (cta_tile[0] // 2, cta_tile[1])
160
-
161
- if tile_description.is_2sm and mma_tile[0] == 128:
162
- tmem_warps = (2, 2)
163
- else:
164
- tmem_warps = (4, 1)
165
-
166
- if self.dag_ir.has_node("C"):
167
- element_c = self.dag_ir.get_node_meta("C").element
168
- element_c_size = DataTypeSize[element_c]
169
- else:
170
- element_c = None
171
- element_c_size = 0
172
-
173
- element_d = self.dag_ir.get_node_meta("D").element
174
-
175
- DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void
176
-
177
- CtaM = cta_tile[0]
178
- CtaN = cta_tile[1]
179
- WarpM = tmem_warps[0]
180
- WarpN = tmem_warps[1]
181
- MaxBits = max(element_c_size, DataTypeSize[element_d])
182
- DpFull = 32
183
- M = min(CtaM, DpFull * WarpM)
184
-
185
- if DisableSource:
186
- # Epilogues w/o residual load are less sensitive to smem allocation
187
- # Target a fixed amount of compute per epilogue iteration
188
- if MaxBits == 4:
189
- # Make epilogue tile larger to reduce the epilogue iterations.
190
- # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
191
- ComputeElts = 8192
192
- Nperf = ComputeElts // M
193
- else:
194
- ComputeElts = 4096
195
- Nperf = ComputeElts // M
196
- else:
197
- # Epilogues w/ residual load are more sensitive to smem allocation
198
- # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
199
- if MaxBits == 32:
200
- Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
201
- elif MaxBits == 16:
202
- Nperf = 32 if CtaN <= 128 else 64
203
- else:
204
- Nperf = 64
205
-
206
- def is_m_major(layout):
207
- return flatten(layout.stride[0]) == 1
208
-
209
- if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
210
- N_min_C = 8 * WarpN
211
- elif element_c_size == 6:
212
- N_min_C = 128 * WarpN
213
- else:
214
- N_min_C = (128 // element_c_size) * WarpN
215
-
216
- if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
217
- N_min_D = 8 * WarpN
218
- elif DataTypeSize[element_d] == 6:
219
- N_min_D = 128 * WarpN
220
- else:
221
- N_min_D = (128 // DataTypeSize[element_d]) * WarpN
222
-
223
- N = min(CtaN, max(Nperf, N_min_C, N_min_D))
224
-
225
- tile_m = M
226
- tile_n_size = N // WarpN * WarpN
227
-
228
- epilogue_tile_mn = (tile_m, tile_n_size)
229
- epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
230
-
231
- stages_d = min(epi_tiles, 2)
232
- reuse_smem_c = (element_c_size > 8)
233
-
234
- if reuse_smem_c:
235
- stages_c = max(min(epi_tiles, 4), stages_d + 1)
236
- else:
237
- stages_c = min(epi_tiles, 4)
238
-
239
- # Record the epilogue tile
240
- self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
241
- self.epilogue_tile_mn = epilogue_tile_mn
242
- self.epi_tiles = epi_tiles
243
- self.stages_c = stages_c
244
- self.stages_d = stages_d
245
- self.reuse_smem_c = reuse_smem_c
246
- self.element_c = element_c
247
- self.element_d = element_d
248
- self.is_source_supported = not DisableSource
249
-
250
- def sm100_epilogue_smem_size(self, tile_description):
251
- """
252
- Compute the shared memory size of sm100 collective epilogue
253
- """
254
- self.sm100_epilogue_tile(tile_description)
255
- return self.sm90_or_sm100_epilogue_smem_size(tile_description)
256
-
257
- def __call__(self, tile_description):
258
- return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
259
-
260
- #
261
- # Helper functions
262
- #
263
-
264
- @staticmethod
265
- def get_visitor_size(members: list, ebo: bool):
266
- """
267
- Get the size of struct in bytes
268
- """
269
- offset = 0
270
- max_alignment = 1
271
- if len(members) > 0:
272
- # Get alignment
273
- for _, alignment in members:
274
- max_alignment = max(max_alignment, alignment)
275
-
276
- for type_size, _ in members:
277
- if type_size != 0:
278
- offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
279
- if type_size == 0 and not ebo:
280
- offset += 1
281
- else:
282
- offset += type_size
283
- offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
284
- return (offset, max_alignment)
285
- else:
286
- # Struct size is at least 1
287
- return (1, 1)
288
-
289
- def get_struct_size(self, members: list):
290
- """
291
- Get the size of struct in bytes
292
- """
293
- return self.get_visitor_size(members, False)
294
-
295
- def get_evt_smem_type(self, node):
296
- # Sort the input nodes by edge weight
297
- input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
298
- input_types.append(self.smem_types[node])
299
- if len(input_types) > 1:
300
- ebo = len(input_types) > 4
301
- self.smem_types[node] = self.get_visitor_size(input_types, ebo)
302
-
303
- def get_dag_smem_type(self, node):
304
- meta = self.dag_ir.get_node_meta(node)
305
- subgraph = meta.subgraph
306
- subgraph_nodes = subgraph.nodes_topological_order()
307
- # Visit the unvisited nodes in subgraph
308
- for n in subgraph_nodes:
309
- m = subgraph.get_node_meta(n)
310
- if m.disabled:
311
- continue
312
- else:
313
- self.smem_types[n] = m.underlying_impl.get_smem_size(
314
- self.cta_tile_mnk, self.epilogue_tile_mn,
315
- self.stages_c, self.stages_d, self.epi_tiles)
316
- input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
317
- if len(input_types) > 0:
318
- ebo = len(input_types) > 4
319
- self.smem_types[node] = self.get_visitor_size(input_types, ebo)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py DELETED
@@ -1,46 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utilities for passes
35
- """
36
-
37
- # Map from the CC of the kernel to the EVT implementation that the CC targets
38
- cc_map = {
39
- 80: 80,
40
- 86: 80,
41
- 89: 80,
42
- 90: 90,
43
- 100: 100,
44
- 101: 100,
45
- 103: 100,
46
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py DELETED
@@ -1,109 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
- from __future__ import annotations
33
-
34
- from cutlass_cppgen.utils.lazy_import import lazy_import
35
- cuda = lazy_import("cuda.cuda")
36
- import numpy as np
37
-
38
- from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
39
- from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
40
-
41
-
42
- class NumpyFrontend:
43
- """
44
- Frontend node for numpy
45
- """
46
-
47
- @staticmethod
48
- def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
49
- """Convert the input numpy tensor to CUDA device pointer
50
-
51
- :param np_tensor: input numpy nd array
52
- :param is_output: whether the tensor is output
53
-
54
- :return: CUDA device pointer
55
- """
56
- # copy the data to device
57
- if is_output:
58
- return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
59
- else:
60
- return todevice(np_tensor)
61
-
62
-
63
- class TorchFrontend:
64
- """
65
- Frontend node for torch
66
- """
67
-
68
- @staticmethod
69
- def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
70
- """Convert the input torch tensor to CUDA device pointer
71
-
72
- :param torch_tensor: input torch tensor
73
- :param is_output: whether the tensor is output
74
-
75
- :return: CUDA device pointer
76
- """
77
-
78
- # check the device of torch_tensor
79
- if not torch_tensor.is_cuda:
80
- torch_tensor = torch_tensor.to("cuda")
81
-
82
- return cuda.CUdeviceptr(torch_tensor.data_ptr())
83
-
84
-
85
- class CupyFrontend:
86
- """
87
- Frontend node for cupy
88
- """
89
-
90
- @staticmethod
91
- def argument(cupy_ndarray: "cp.ndarray"):
92
- return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))
93
-
94
-
95
- class TensorFrontend:
96
- """
97
- Universal Frontend for client-provide tensors
98
- """
99
-
100
- @staticmethod
101
- def argument(tensor, is_output=False):
102
- if is_numpy_tensor(tensor):
103
- return NumpyFrontend.argument(tensor, is_output)
104
- elif is_torch_tensor(tensor):
105
- return TorchFrontend.argument(tensor)
106
- elif is_cupy_tensor(tensor):
107
- return CupyFrontend.argument(tensor)
108
- else:
109
- raise NotImplementedError("Unknown Tensor Type")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py DELETED
@@ -1,2145 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
- from __future__ import annotations
33
-
34
- import copy
35
- import ctypes
36
- import enum
37
-
38
- from cutlass_cppgen.utils.lazy_import import lazy_import
39
- cuda = lazy_import("cuda.cuda")
40
- cudart = lazy_import("cuda.cudart")
41
- from cutlass_library import SubstituteTemplate
42
- import numpy as np
43
-
44
- from cutlass_library import (
45
- ComplexTransformTag,
46
- DataType,
47
- DataTypeNames,
48
- DataTypeSize,
49
- DataTypeTag,
50
- EpilogueScheduleSuffixes,
51
- EpilogueScheduleTag,
52
- EpilogueScheduleType,
53
- GemmKind,
54
- GemmKindNames,
55
- GemmUniversalMode,
56
- KernelScheduleSuffixes,
57
- KernelScheduleTag,
58
- KernelScheduleType,
59
- LayoutTag,
60
- LayoutType,
61
- MathOperation,
62
- MathOperationTag,
63
- OpcodeClass,
64
- OpcodeClassNames,
65
- OpcodeClassTag,
66
- OperationKind,
67
- ShortComplexLayoutNames,
68
- ShortDataTypeNames,
69
- ShortLayoutTypeNames,
70
- SwizzlingFunctor,
71
- SwizzlingFunctorTag,
72
- TileSchedulerSuffixes,
73
- TileSchedulerTag,
74
- TileSchedulerType,
75
- get_complex_from_real
76
- )
77
- from cutlass_cppgen.backend.arguments import ArgumentBase
78
- from cutlass_cppgen.backend.c_types import (
79
- GemmCoord_,
80
- GemmCoordBatched_,
81
- GenericMainloopArguments3x_,
82
- StrideBatched_,
83
- dim3_,
84
- get_gemm_arguments,
85
- get_gemm_arguments_3x,
86
- get_gemm_arguments_streamk,
87
- get_gemm_grouped_arguments,
88
- get_mainloop_arguments_3x,
89
- get_tile_scheduler_arguments_3x,
90
- )
91
- from cutlass_cppgen.backend.library import (
92
- ApiVersion,
93
- EmissionType,
94
- SchedulerMode,
95
- SchedulerModeTag,
96
- TensorDescription,
97
- TileDescription,
98
- api_version,
99
- )
100
- from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
101
- from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
102
- from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor
103
- from cutlass_cppgen.backend.utils.device import device_sm_count
104
- from cutlass_cppgen.shape import GemmCoord, MatrixCoord
105
-
106
-
107
- ################################################################################
108
- #
109
- # Data structure modeling a GEMM operation
110
- #
111
- ################################################################################
112
-
113
-
114
- def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int:
115
- """
116
- Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``.
117
-
118
- :param layout: layout of the tensor
119
- :type layout: cutlass_cppgen.shape.LayoutType
120
- :param shape: shape of the tensor
121
- :type shape: cutlass_cppgen.shape.MatrixCoord
122
-
123
- :return: leading dimension of the tensor
124
- :rtype: int
125
- """
126
- if layout == LayoutType.RowMajor:
127
- return shape.column
128
- elif layout == LayoutType.ColumnMajor:
129
- return shape.row
130
-
131
-
132
- def transpose_layout(layout: LayoutType) -> LayoutType:
133
- if layout == LayoutType.ColumnMajor:
134
- return LayoutType.RowMajor
135
- elif layout == LayoutType.RowMajor:
136
- return LayoutType.ColumnMajor
137
- else:
138
- raise ValueError(f"Unsupported Layout {layout}")
139
-
140
-
141
- class GemmArguments2x(ArgumentBase):
142
- """
143
- Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and
144
- user-provide tensors into the kernel's argument
145
-
146
- :param operation: the GEMM operation to take the argument
147
- :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
148
- :class:`cutlass_cppgen.backend.GemmOperationGrouped`
149
-
150
- :param problem_size: GEMM problem size gemm(M, N, K)
151
- :type operation: :class:`cutlass_cppgen.shape.GemmCoord`
152
-
153
- :param A: tensor A
154
- :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
155
-
156
- :param B: tensor B
157
- :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
158
-
159
- :param C: tensor C
160
- :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
161
-
162
- :param D: tensor D
163
- :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
164
-
165
- :param gemm_mode: GEMM mode
166
- :type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
167
-
168
- :param output_op: output operator, optional
169
- :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
170
-
171
- :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
172
- :type stream: :class:`cuda.cuda.CUstream`
173
- """
174
-
175
- def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
176
- self.operation = operation
177
-
178
- self.layout_A = operation.A.layout
179
- self.layout_B = operation.B.layout
180
- self.layout_C = operation.C.layout
181
-
182
- self.element_A = operation.A.element
183
- self.element_B = operation.B.element
184
- self.element_C = operation.C.element
185
-
186
- if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]:
187
- raise Exception("Interleaved layout not currently supported")
188
-
189
- if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]:
190
- super().__init__(A, B, None, None, **kwargs)
191
- else:
192
- super().__init__(A, B, C, D, **kwargs)
193
-
194
- if operation.switched:
195
- self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
196
- self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A
197
- else:
198
- self.problem_size = problem_size
199
- # If the number of elements in C = problem_size.n, C is treated as the bias
200
- if hasattr(self, "tensor_c_numel"):
201
- if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1:
202
- self.bias = True
203
-
204
- self.lda = leading_dimension(self.layout_A, self.problem_size.mk)
205
- self.ldb = leading_dimension(self.layout_B, self.problem_size.kn)
206
- self.ldc = leading_dimension(self.layout_C, self.problem_size.mn)
207
- self.ldd = self.ldc
208
-
209
- if self.bias:
210
- self.ldc = 0
211
-
212
- if "output_op" in kwargs.keys() and gemm_mode != GemmUniversalMode.GemmSplitKParallel:
213
- self.output_op = kwargs["output_op"]
214
- else:
215
- if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
216
- dtype = int
217
- else:
218
- dtype = float
219
- self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0))
220
-
221
- self.gemm_mode = gemm_mode
222
- if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
223
- if "split_k_slices" in kwargs.keys():
224
- self.batch_count = kwargs["split_k_slices"]
225
- else:
226
- self.batch_count = 1
227
- self.split_k_slices = self.batch_count
228
-
229
- if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]:
230
- if "batch" in kwargs.keys():
231
- self.batch_count = kwargs["batch"]
232
- else:
233
- self.batch_count = 1
234
-
235
- if "batch_strides" in kwargs:
236
- self.batched_stride_A = kwargs["batch_strides"]["A"]
237
- self.batched_stride_B = kwargs["batch_strides"]["B"]
238
- self.batched_stride_C = kwargs["batch_strides"]["C"]
239
- self.batched_stride_D = kwargs["batch_strides"]["D"]
240
- else:
241
- self.batched_stride_A = self.problem_size.m * self.problem_size.k
242
- self.batched_stride_B = self.problem_size.n * self.problem_size.k
243
- self.batched_stride_C = self.problem_size.m * self.problem_size.n
244
- self.batched_stride_D = self.problem_size.m * self.problem_size.n
245
-
246
- if self.bias:
247
- self.batched_stride_C = self.problem_size.n
248
-
249
- if gemm_mode == GemmUniversalMode.Array:
250
- self.ptr_A_array = []
251
- self.ptr_B_array = []
252
- self.ptr_C_array = []
253
- self.ptr_D_array = []
254
-
255
- ptr_A_addr = int(self.ptr_A)
256
- ptr_B_addr = int(self.ptr_B)
257
- ptr_C_addr = int(self.ptr_C)
258
- ptr_D_addr = int(self.ptr_D)
259
-
260
- stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
261
- stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
262
- stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
263
- stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
264
- for _ in range(self.batch_count):
265
- self.ptr_A_array.append(ptr_A_addr)
266
- self.ptr_B_array.append(ptr_B_addr)
267
- self.ptr_C_array.append(ptr_C_addr)
268
- self.ptr_D_array.append(ptr_D_addr)
269
-
270
- ptr_A_addr += stride_A
271
- ptr_B_addr += stride_B
272
- ptr_C_addr += stride_C
273
- ptr_D_addr += stride_D
274
-
275
- self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
276
- self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
277
- self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
278
- self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)
279
-
280
- if isinstance(self.operation, GemmOperationUniversal):
281
- self.initialize()
282
-
283
- def get_arguments(self):
284
- problem_size_ = self.problem_size.ctype
285
- grid_tiled_shape_ = GemmCoord(
286
- self.grid_tiled_shape.x,
287
- self.grid_tiled_shape.y,
288
- self.grid_tiled_shape.z ).ctype
289
-
290
- if self.gemm_mode == GemmUniversalMode.Array:
291
- arguments = self.operation.argument_type(
292
- # Arguments from UniversalArgumentsBase
293
- self.gemm_mode,
294
- problem_size_,
295
- self.batch_count,
296
- 0,
297
- # Remaining arguments
298
- self.output_op,
299
- int(self.ptr_A_array_buffer.ptr),
300
- int(self.ptr_B_array_buffer.ptr),
301
- int(self.ptr_C_array_buffer.ptr),
302
- int(self.ptr_D_array_buffer.ptr),
303
- 0, 0, 0,
304
- self.lda, self.ldb, self.ldc, self.ldd,
305
- self.lda, self.ldb, self.ldc, self.ldd,
306
- 0, 0, 0
307
- )
308
- else:
309
- arguments = self.operation.argument_type(
310
- # Arguments from UniversalArgumentsBase
311
- self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
312
- # Remaining arguments
313
- self.output_op,
314
- int(self.ptr_A),
315
- int(self.ptr_B),
316
- int(self.ptr_C),
317
- int(self.ptr_D),
318
- self.batched_stride_A,
319
- self.batched_stride_B,
320
- self.batched_stride_C,
321
- self.lda, self.ldb, self.ldc, self.ldd,
322
- self.lda, self.ldb, self.ldc, self.ldd,
323
- 0, 0, 0
324
- )
325
-
326
- self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size
327
-
328
- def initialize(self):
329
- launch_config = self.operation.rt_module.plan(self)
330
-
331
- # Get the host and device workspace
332
- device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
333
-
334
- if device_workspace_size > 0:
335
- self.workspace_buffer = device_mem_alloc(device_workspace_size)
336
- workspace_ptr = self.workspace_buffer.ptr
337
- err, = cuda.cuMemsetD32(
338
- workspace_ptr, 0, device_workspace_size // 4)
339
- else:
340
- workspace_ptr = None
341
-
342
- device_workspace = 0
343
- if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
344
- # In GEMM splik-K parallel, the D pointer is redirected to the workspace
345
- self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
346
- elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
347
- device_workspace = workspace_ptr
348
-
349
- self.get_arguments()
350
-
351
- arguments, grid_tiled_shape, gemm_k_size = self.arguments
352
- res_arg = self.operation.rt_module.get_args(
353
- ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
354
- host_workspace = bytearray(res_arg.contents)
355
-
356
- device_workspace = None
357
-
358
- self.host_workspace = host_workspace
359
- self.device_workspace = device_workspace
360
- self.launch_config = launch_config
361
-
362
- def sync(self, stream_sync=True):
363
- super().sync(stream_sync)
364
- if hasattr(self.output_op, "sync"):
365
- self.output_op.sync()
366
-
367
-
368
- class GemmArguments2xStreamK(GemmArguments2x):
369
- """
370
- Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and
371
- user-provide tensors into the kernel's argument
372
-
373
- :param operation: the GEMM operation to take the argument
374
- :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
375
- :class:`cutlass_cppgen.backend.GemmOperationGrouped`
376
-
377
- :param problem_size: GEMM problem size gemm(M, N, K)
378
- :type operation: :class:`cutlass_cppgen.shape.GemmCoord`
379
-
380
- :param A: tensor A
381
- :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
382
-
383
- :param B: tensor B
384
- :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
385
-
386
- :param C: tensor C
387
- :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
388
-
389
- :param D: tensor D
390
- :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
391
-
392
- :param gemm_mode: GEMM mode
393
- :type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
394
-
395
- :param output_op: output operator, optional
396
- :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
397
- """
398
-
399
- def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
400
- if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
401
- raise Exception(f"Unsupported GEMM mode {gemm_mode}.")
402
-
403
- super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
404
-
405
- def get_arguments(self):
406
- batch_stride_A = self.problem_size.m * self.problem_size.k
407
- batch_stride_B = self.problem_size.k * self.problem_size.n
408
- batch_stride_C = self.problem_size.m * self.problem_size.n
409
- batch_stride_D = self.problem_size.m * self.problem_size.n
410
-
411
- arguments = self.operation.argument_type(
412
- self.gemm_mode,
413
- GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k),
414
- self.batch_count,
415
- self.output_op,
416
- int(self.ptr_A),
417
- int(self.ptr_B),
418
- int(self.ptr_C),
419
- int(self.ptr_D),
420
- batch_stride_A,
421
- batch_stride_B,
422
- batch_stride_C,
423
- batch_stride_D,
424
- self.lda, self.ldb, self.ldc, self.ldd, # strides
425
- self.lda, self.ldb, self.ldc, self.ldd,
426
- -1, # avail_sms
427
- )
428
- return arguments
429
-
430
- def initialize(self):
431
- # Get the host and device workspace
432
- device_workspace_size = self.operation.rt_module.get_device_workspace_size(
433
- self,
434
- device_sm_count(),
435
- self.operation.rt_module.occupancy
436
- )
437
-
438
- if device_workspace_size > 0:
439
- self.workspace_buffer = device_mem_alloc(device_workspace_size)
440
- workspace_ptr = self.workspace_buffer.ptr
441
- err, = cuda.cuMemsetD32(
442
- workspace_ptr, 0, device_workspace_size // 4)
443
- else:
444
- workspace_ptr = None
445
-
446
- device_workspace = 0
447
- if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
448
- # In GEMM splik-K parallel, the D pointer is redirected to the workspace
449
- self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
450
- elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
451
- device_workspace = workspace_ptr
452
-
453
- arguments = self.get_arguments()
454
-
455
- res_arg = self.operation.rt_module.get_args(
456
- ctypes.byref(arguments),
457
- ctypes.c_void_p(int(device_workspace)),
458
- device_sm_count(),
459
- self.operation.rt_module.occupancy
460
- )
461
- host_workspace = bytearray(res_arg.contents)
462
-
463
- grid = self.operation.rt_module.get_grid_shape(
464
- ctypes.byref(arguments),
465
- device_sm_count(),
466
- self.operation.rt_module.occupancy
467
- )
468
-
469
- device_workspace = None
470
-
471
- self.host_workspace = host_workspace
472
- self.device_workspace = device_workspace
473
- self.launch_config = LaunchConfiguration(
474
- [grid.m, grid.n, grid.k],
475
- [self.operation.rt_module.threads, 1, 1],
476
- self.operation.rt_module.shared_memory_capacity
477
- )
478
-
479
-
480
- class GemmArguments3x(GemmArguments2x):
481
- """
482
- Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and
483
- user-provide tensors into the kernel's argument
484
-
485
- :param operation: the GEMM operation to take the argument
486
- :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
487
- :class:`cutlass_cppgen.backend.GemmOperationGrouped`
488
-
489
- :param problem_size: GEMM problem size gemm(M, N, K)
490
- :type operation: :class:`cutlass_cppgen.shape.GemmCoord`
491
-
492
- :param A: tensor A
493
- :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
494
-
495
- :param B: tensor B
496
- :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
497
-
498
- :param C: tensor C
499
- :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
500
-
501
- :param D: tensor D
502
- :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
503
-
504
- :param gemm_mode: GEMM mode
505
- :type gemm_mode: GemmUniversalMode
506
-
507
- :param output_op: output operator, optional
508
- :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
509
- """
510
-
511
- def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
512
- if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
513
- raise Exception(f"Unsupported GEMM mode {gemm_mode}.")
514
-
515
- super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
516
-
517
- def get_arguments(self):
518
- mainloop_args = get_mainloop_arguments_3x(
519
- self.operation.tile_description.kernel_schedule,
520
- self.operation.A.element,
521
- self.operation.B.element,
522
- self.operation.A.alignment,
523
- self.operation.B.alignment
524
- )
525
- scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
526
- uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
527
- argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
528
- mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
529
-
530
- problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)
531
-
532
- if self.batch_count > 1:
533
- bsA = self.batched_stride_A
534
- bsB = self.batched_stride_B
535
- bsC = self.batched_stride_C
536
- bsD = self.batched_stride_D
537
- else:
538
- bsA = 0
539
- bsB = 0
540
- bsC = 0
541
- bsD = 0
542
- stride_A = StrideBatched_(self.lda, bsA)
543
- stride_B = StrideBatched_(self.ldb, bsB)
544
- stride_C = StrideBatched_(self.ldc, bsC)
545
- stride_D = StrideBatched_(self.ldd, bsD)
546
-
547
- # Superset of potential mainloop arguments
548
- generic_args = GenericMainloopArguments3x_(
549
- int(self.ptr_A),
550
- stride_A,
551
- int(self.ptr_B),
552
- stride_B,
553
- 4 # mma_promotion_interval
554
- )
555
-
556
- # Set of mainloop arguments needed for this kernel
557
- mainloop = mainloop_args.from_generic_mainloop_args(generic_args)
558
-
559
- if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
560
- self.output_op = self.output_op.to_evt_params()
561
-
562
- epilogue = epilogue_args(
563
- self.output_op,
564
- int(self.ptr_C),
565
- stride_C,
566
- int(self.ptr_D),
567
- stride_D,
568
- )
569
-
570
- # Set hardware info
571
- hw_info_ = hw_info(
572
- 0, device_sm_count(), 0,
573
- dim3_(0,0,0),
574
- dim3_(0,0,0),
575
- )
576
-
577
- self.arguments = argument_type(
578
- int(self.gemm_mode),
579
- problem_size_,
580
- mainloop,
581
- epilogue,
582
- hw_info_,
583
- scheduler_args
584
- )
585
- return self.arguments
586
-
587
- def initialize(self):
588
- # Get the host and evice workspace
589
- device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
590
-
591
- if device_workspace_size > 0:
592
- self.workspace_buffer = device_mem_alloc(device_workspace_size)
593
- workspace_ptr = self.workspace_buffer.ptr
594
- err, = cuda.cuMemsetD32(
595
- workspace_ptr, 0, device_workspace_size // 4)
596
- else:
597
- workspace_ptr = None
598
-
599
- device_workspace = 0
600
- if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
601
- # In GEMM splik-K parallel, the D pointer is redirected to the workspace
602
- self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
603
- elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
604
- device_workspace = workspace_ptr
605
-
606
- self.get_arguments()
607
- res_arg = self.operation.rt_module.get_args(
608
- ctypes.byref(self.arguments),
609
- ctypes.c_void_p(int(device_workspace)),
610
- )
611
- host_workspace = bytearray(res_arg.contents)
612
-
613
- grid = self.operation.rt_module.get_grid_shape(
614
- ctypes.byref(self.arguments),
615
- ctypes.c_void_p(int(device_workspace)),
616
- )
617
- block = self.operation.rt_module.get_block_shape()
618
-
619
- device_workspace = None
620
-
621
- self.host_workspace = host_workspace
622
- self.device_workspace = device_workspace
623
- self.launch_config = LaunchConfiguration(
624
- [grid.x, grid.y, grid.z],
625
- [block.x, block.y, block.z],
626
- self.operation.rt_module.shared_memory_capacity,
627
- )
628
-
629
-
630
- def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
631
- """
632
- Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments
633
- or 3x arguments depending on the `arch` field specified in `operation`.
634
-
635
- :param operation: the GEMM operation to take the argument
636
- :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
637
- :class:`cutlass_cppgen.backend.GemmOperationGrouped`
638
-
639
- :param problem_size: GEMM problem size gemm(M, N, K)
640
- :type operation: :class:`cutlass_cppgen.shape.GemmCoord`
641
-
642
- :param A: tensor A
643
- :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
644
-
645
- :param B: tensor B
646
- :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
647
-
648
- :param C: tensor C
649
- :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
650
-
651
- :param D: tensor D
652
- :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
653
-
654
- :param gemm_mode: GEMM mode
655
- :type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
656
-
657
- :param output_op: output operator, optional
658
- :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
659
- """
660
- if operation.swizzling_functor == SwizzlingFunctor.StreamK:
661
- if operation.api == ApiVersion.v3x:
662
- raise Exception("Stream K is currently only supported in CUTLASS 2.x")
663
- ArgClass = GemmArguments2xStreamK
664
- else:
665
- ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x
666
- return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
667
-
668
-
669
- class GemmGroupedArguments:
670
- """
671
- Argument wrapper for GEMM Grouped. It encodes problem information and
672
- user-provide tensors into the kernel's argument
673
-
674
- :param operation: the GEMM Grouped operation to take the argument
675
- :type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped`
676
-
677
- :param problem_size: list of GEMM problem size gemm(M, N, K)
678
- :type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`]
679
-
680
- :param A: list of tensor A
681
- :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
682
-
683
- :param B: list of tensor B
684
- :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
685
-
686
- :param C: list of tensor C
687
- :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
688
-
689
- :param D: list of tensor D
690
- :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
691
-
692
- :param output_op: output operator, optional
693
- :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
694
-
695
- :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
696
- :type stream: :class:`cuda.cuda.CUstream`
697
- """
698
-
699
- def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
700
- # Get number of problems in the group
701
- self.problem_count = len(problem_sizes)
702
-
703
- # Check the input arguments
704
- assert len(A) == self.problem_count
705
- assert len(B) == self.problem_count
706
- assert len(C) == self.problem_count
707
- assert len(D) == self.problem_count
708
-
709
- problem_size_host = []
710
- self.ptr_A_host = []
711
- self.ptr_B_host = []
712
- self.ptr_C_host = []
713
- self.ptr_D_host = []
714
-
715
- lda_host = []
716
- ldb_host = []
717
- ldc_host = []
718
- ldd_host = []
719
-
720
- self.partitions = 1
721
-
722
- self.operation = operation
723
-
724
- # Get the threadblock
725
- threadblock_shape = operation.tile_description.threadblock_shape
726
- self.threadblock_shape = GemmCoord(
727
- threadblock_shape[0],
728
- threadblock_shape[1],
729
- threadblock_shape[2],
730
- )
731
- self.threadblock_swizzle = operation.swizzling_functor
732
-
733
- self.total_tiles = 0
734
-
735
- self.gemm_arguments = []
736
-
737
- self.stream = kwargs.get("stream", cuda.CUstream(0))
738
-
739
- # Process the input arguments
740
- for idx, problem_size in enumerate(problem_sizes):
741
- M, N, K = problem_size.m, problem_size.n, problem_size.k
742
- temp_argument = GemmArguments2x(
743
- operation=operation,
744
- problem_size=GemmCoord(M, N, K),
745
- A=A[idx], B=B[idx], C=C[idx], D=D[idx])
746
- self.gemm_arguments.append(temp_argument)
747
-
748
- problem_size_host.append(
749
- [temp_argument.problem_size.m,
750
- temp_argument.problem_size.n,
751
- temp_argument.problem_size.k]
752
- )
753
-
754
- self.ptr_A_host.append(int(temp_argument.ptr_A))
755
- lda_host.append(temp_argument.lda)
756
-
757
- self.ptr_B_host.append(int(temp_argument.ptr_B))
758
- ldb_host.append(temp_argument.ldb)
759
-
760
- self.ptr_C_host.append(int(temp_argument.ptr_C))
761
- ldc_host.append(temp_argument.ldc)
762
-
763
- self.ptr_D_host.append(int(temp_argument.ptr_D))
764
- ldd_host.append(temp_argument.ldd)
765
-
766
- # Get number of tiles
767
- grid = self.operation.rt_module.get_grid_shape(
768
- self.operation.rt_module.get_tiled_shape(
769
- temp_argument.problem_size.ctype,
770
- self.threadblock_shape.ctype,
771
- temp_argument.batch_count
772
- )
773
- )
774
- self.total_tiles += grid.x * grid.y * grid.z
775
-
776
- self.problem_size_buffer = todevice(problem_size_host, np.int32)
777
- self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64)
778
- self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64)
779
- self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64)
780
- self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64)
781
-
782
- self.lda_buffer = todevice(lda_host, np.int64)
783
- self.ldb_buffer = todevice(ldb_host, np.int64)
784
- self.ldc_buffer = todevice(ldc_host, np.int64)
785
- self.ldd_buffer = todevice(ldd_host, np.int64)
786
-
787
- if "output_op" in kwargs.keys():
788
- self.alpha = kwargs["output_op"].alpha
789
- self.beta = kwargs["output_op"].beta
790
- else:
791
- self.alpha = 1.0
792
- self.beta = 0.0
793
-
794
- if "output_op" in kwargs.keys():
795
- self.output_op = kwargs["output_op"]
796
- else:
797
- self.output_op = self.operation.epilogue_type(1.0, 0.0)
798
-
799
- # Get host problem size
800
- self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
801
-
802
- self.arguments = self.get_arguments()
803
-
804
- self.initialize()
805
-
806
- def get_arguments(self):
807
- return self.operation.argument_type(
808
- self.problem_size_buffer.ptr,
809
- self.problem_count,
810
- self.total_tiles,
811
- self.output_op,
812
- self.ptr_A_buffer.ptr,
813
- self.ptr_B_buffer.ptr,
814
- self.ptr_C_buffer.ptr,
815
- self.ptr_D_buffer.ptr,
816
- self.lda_buffer.ptr,
817
- self.ldb_buffer.ptr,
818
- self.ldc_buffer.ptr,
819
- self.ldd_buffer.ptr,
820
- ctypes.c_void_p(int(self.host_problem_size_ptr)),
821
- )
822
-
823
- def initialize(self):
824
- # Get launch configuration
825
- launch_config = self.operation.rt_module.plan(self)
826
-
827
- # Get the host and evice workspace
828
- device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
829
-
830
- if device_workspace_size > 0:
831
- self.workspace_buffer = device_mem_alloc(device_workspace_size)
832
- workspace_ptr = self.workspace_buffer.ptr
833
- err, = cuda.cuMemsetD32(
834
- workspace_ptr, 0, device_workspace_size // 4)
835
- else:
836
- workspace_ptr = None
837
-
838
- if self.operation.precompute_mode == SchedulerMode.Host:
839
- device_workspace_ptr = self.operation.rt_module.host_precompute(
840
- self, self.operation.rt_module.get_workspace_size(self),)
841
- else:
842
- device_workspace_ptr = 0
843
-
844
- result = self.operation.rt_module.get_args(
845
- ctypes.byref(self.arguments),
846
- self.total_tiles,
847
- ctypes.c_void_p(int(device_workspace_ptr)),
848
- )
849
- host_workspace = bytearray(result.contents)
850
-
851
- device_workspace = None
852
-
853
- self.host_workspace = host_workspace
854
- self.device_workspace = device_workspace
855
- self.launch_config = launch_config
856
-
857
- def sync(self):
858
- err, = cudart.cudaDeviceSynchronize()
859
- if err != cuda.CUresult.CUDA_SUCCESS:
860
- raise RuntimeError("CUDA Error %s" % str(err))
861
- for arg in self.gemm_arguments:
862
- arg.sync(stream_sync=False)
863
-
864
-
865
- ################################################################################
866
- # Base class for GEMM runtime module
867
- ################################################################################
868
-
869
-
870
- class GemmRTbase(ExecutableOperation):
871
- """
872
- GemmRT manages the CUTLASS runtime components
873
- """
874
-
875
- KernelTemplate = r"""
876
- extern "C"
877
- __global__ void
878
- ${operation_name}(${operation_name}${operation_suffix}::Params params) {
879
-
880
- // Dynamic shared memory base pointer
881
- extern __shared__ int SharedStorageBase[];
882
-
883
- // Declare pointer to dynamic shared memory.
884
- ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
885
- reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
886
-
887
- ${operation_name}${operation_suffix}::invoke(params, *shared_storage);
888
- }
889
- """
890
-
891
- def __init__(self, operation: "GemmOperation"):
892
- super().__init__(operation)
893
-
894
- self.operation = operation
895
- threadblock_shape = operation.tile_description.threadblock_shape
896
- self.threadblock_shape = GemmCoord(
897
- threadblock_shape[0], threadblock_shape[1], threadblock_shape[2])
898
- self.threadblock_swizzle = operation.swizzling_functor
899
-
900
- # Threads per threadblock
901
- self.threads = operation.tile_description.num_threads
902
-
903
- def emit(self):
904
- return self.emitter.emit(self.operation)
905
-
906
- def can_implement(self, configuration, arguments):
907
- raise NotImplementedError()
908
-
909
- def get_host_workspace_size(self, arguments):
910
- raise NotImplementedError()
911
-
912
- def get_device_workspace_size(self, arguments):
913
- return 0
914
-
915
- def initialize(self):
916
- err, = cuda.cuFuncSetAttribute(
917
- self.kernel,
918
- attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
919
- value=self.shared_memory_capacity)
920
- if err != cuda.CUresult.CUDA_SUCCESS:
921
- raise RuntimeError(
922
- f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}"
923
- )
924
-
925
-
926
- ################################################################################
927
- # Runtime module for GEMM Universal
928
- ################################################################################
929
-
930
-
931
- class GemmRTUniversal(GemmRTbase):
932
- """
933
- GemmRTUniversal manages the CUTLASS runtime components
934
- """
935
-
936
- HostTemplate = r"""
937
- extern "C" {
938
- // Get the size of params in bytes
939
- int ${operation_name}_get_param_size(){
940
- return sizeof(${operation_name}${operation_suffix}::Params);
941
- }
942
-
943
- // Get the size of dynamic shared memory in bytes
944
- int ${operation_name}_shared_memory_size() {
945
- return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
946
- }
947
-
948
- // Get the params as byte array
949
- char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
950
- ${operation_name}_base::Params* params;
951
- params = new ${operation_name}_base::Params(*argument,
952
- -1, // SM count. Only used for stream-K
953
- -1 // Occupancy. Only used for stream-K
954
- );
955
-
956
- // Semaphore holds the pointer to the workspace in the Params struct
957
- params->semaphore = workspace;
958
-
959
- char *bytes = ((char*)(params));
960
- char *output = new char[sizeof(${operation_name}_base::Params)];
961
- for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
962
- output[i] = bytes[i];
963
-
964
- return output;
965
- }
966
-
967
- cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
968
- cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
969
- return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
970
- problem_size, tile_size, split_k_slices);
971
- }
972
-
973
- dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
974
- return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
975
- }
976
- }
977
- """
978
-
979
- def __init__(self, operation):
980
- super(GemmRTUniversal, self).__init__(operation)
981
- self.extra_funcs = {
982
- "get_tiled_shape": GemmCoord_,
983
- "get_grid_shape": dim3_,
984
- }
985
- self.emitter = EmitGemmUniversalInstance(
986
- "_type", operation.direct_store)
987
-
988
- self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
989
- self.argtype = [
990
- ctypes.POINTER(self.argument_type),
991
- ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
992
- ]
993
-
994
- def plan(self, arguments):
995
- grid = self.get_tiled_shape(
996
- arguments.problem_size.ctype,
997
- self.threadblock_shape.ctype,
998
- arguments.batch_count
999
- )
1000
-
1001
- gemm_k_size = arguments.problem_size.k
1002
- if arguments.gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
1003
- alignk = max(max(128 // DataTypeSize[self.operation.A.element],
1004
- 128 // DataTypeSize[self.operation.B.element]), 1)
1005
-
1006
- gemm_k_size = (((arguments.problem_size.k + arguments.batch_count - 1) //
1007
- arguments.batch_count + alignk - 1) // alignk) * alignk
1008
-
1009
- if gemm_k_size:
1010
- grid_z = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size
1011
- grid = GemmCoord(grid.m, grid.n, grid_z).ctype
1012
-
1013
- arguments.grid_tiled_shape = dim3_(grid.m, grid.n, grid.k)
1014
- grid = self.get_grid_shape(grid)
1015
- arguments.gemm_k_size = gemm_k_size
1016
- return LaunchConfiguration(
1017
- [grid.x, grid.y, grid.z],
1018
- [self.threads, 1, 1],
1019
- self.shared_memory_capacity)
1020
-
1021
- def get_device_workspace_size(self, arguments: GemmArguments):
1022
- workspace_bytes = 0
1023
- if arguments.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
1024
- workspace_bytes = (DataTypeSize[arguments.operation.C.element]
1025
- * arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8)
1026
- elif (arguments.gemm_mode == GemmUniversalMode.Gemm and
1027
- arguments.split_k_slices > 1):
1028
- workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y
1029
-
1030
- return workspace_bytes
1031
-
1032
-
1033
- class GemmRTUniversalStreamK(GemmRTUniversal):
1034
- """
1035
- Manages the CUTLASS runtime components for 2.x stream K kernels
1036
- """
1037
-
1038
- HostTemplate = r"""
1039
- extern "C" {
1040
- // Get the size of params in bytes
1041
- int ${operation_name}_get_param_size(){
1042
- return sizeof(${operation_name}${operation_suffix}::Params);
1043
- }
1044
-
1045
- // Get the size of dynamic shared memory in bytes
1046
- int ${operation_name}_shared_memory_size() {
1047
- return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
1048
- }
1049
-
1050
- using GemmType = ${operation_name}_base;
1051
-
1052
- // Get the params as byte array
1053
- char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace,
1054
- int sm_count, int occupancy) {
1055
- GemmType::Params* params;
1056
- params = new GemmType::Params(*argument, sm_count, occupancy);
1057
-
1058
- params->init_workspace(workspace);
1059
-
1060
- char *bytes = ((char*)(params));
1061
- char *output = new char[sizeof(GemmType::Params)];
1062
- for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
1063
- output[i] = bytes[i];
1064
-
1065
- return output;
1066
- }
1067
-
1068
- dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
1069
- typename GemmType::Params params(*args, device_sms, sm_occupancy);
1070
- return params.get_grid_dims();
1071
- }
1072
-
1073
- uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
1074
- typename GemmType::Params params(*args, device_sms, sm_occupancy);
1075
- return params.get_workspace_size();
1076
- }
1077
- }
1078
- """
1079
-
1080
- def __init__(self, operation: "GemmOperation"):
1081
- super(GemmRTUniversalStreamK, self).__init__(operation)
1082
- self.extra_funcs = {
1083
- "get_grid_shape": GemmCoord_,
1084
- "get_kernel_workspace_size": ctypes.c_uint64,
1085
- }
1086
- self._occupancy = None
1087
- self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor)
1088
-
1089
- @property
1090
- def occupancy(self):
1091
- if self._occupancy is None:
1092
- err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
1093
- self.kernel, self.threads, self.shared_memory_capacity,
1094
- cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE)
1095
-
1096
- if err != cuda.CUresult.CUDA_SUCCESS:
1097
- raise RuntimeError(
1098
- "CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: "
1099
- f"{cuda.cuGetErrorString(err)[1]}")
1100
- return self._occupancy
1101
-
1102
- def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int):
1103
- return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy)
1104
-
1105
-
1106
- ################################################################################
1107
- # Runtime module for GEMM Universal within CUTLASS 3
1108
- ################################################################################
1109
-
1110
-
1111
- class GemmRTUniversal3x(GemmRTUniversal):
1112
- """
1113
- Manages the CUTLASS runtime components for 3.x kernels
1114
- """
1115
-
1116
- KernelTemplate = r"""
1117
-
1118
- using Operator = ${operation_name}${operation_suffix};
1119
- extern "C"
1120
- __global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
1121
- void ${operation_name}(__grid_constant__ typename Operator::Params const params) {
1122
- // Dynamic shared memory base pointer
1123
- extern __shared__ char smem[];
1124
-
1125
- // Declare pointer to dynamic shared memory.
1126
- Operator op;
1127
- op(params, smem);
1128
- }
1129
- """
1130
- HostTemplate = r"""
1131
- extern "C" {
1132
- // Get the size of params in bytes
1133
- int ${operation_name}_get_param_size(){
1134
- return sizeof(${operation_name}${operation_suffix}::Params);
1135
- }
1136
-
1137
- // Get the size of dynamic shared memory in bytes
1138
- int ${operation_name}_shared_memory_size() {
1139
- return ${operation_name}${operation_suffix}::SharedStorageSize;
1140
- }
1141
-
1142
- using GemmType = ${operation_name}_base;
1143
-
1144
- bool ${operation_name}_uses_default_epilogue() {
1145
- return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
1146
- }
1147
-
1148
- // Get the workspace size
1149
- uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
1150
- return GemmType::get_workspace_size(*argument);
1151
- }
1152
-
1153
- // Get the params as byte array
1154
- char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){
1155
- GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace);
1156
- char *bytes = ((char*)(&params));
1157
- char *output = new char[sizeof(GemmType::Params)];
1158
- for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
1159
- output[i] = bytes[i];
1160
-
1161
- return output;
1162
- }
1163
-
1164
- // Get the total number of blocks for a persistent kernel
1165
- uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) {
1166
- auto problem_shape_MNKL = append<4>(problem, Int<1>{});
1167
- auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
1168
- cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl(
1169
- problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{});
1170
- return problem_blocks_m * problem_blocks_n * problem_blocks_l;
1171
- }
1172
-
1173
- // Get the grid shape
1174
- dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) {
1175
- auto tmp_params = GemmType::to_underlying_arguments(*args, workspace);
1176
- return GemmType::get_grid_shape(tmp_params);
1177
- }
1178
-
1179
- // Get the block shape
1180
- dim3 ${operation_name}_get_block_shape() {
1181
- return GemmType::get_block_shape();
1182
- }
1183
- }
1184
- """
1185
-
1186
- def __init__(self, operation):
1187
- super(GemmRTUniversal3x, self).__init__(operation)
1188
- self.extra_funcs = {
1189
- "get_grid_shape": dim3_,
1190
- "get_block_shape": dim3_,
1191
- "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
1192
- "get_kernel_workspace_size": ctypes.c_uint64,
1193
- "uses_default_epilogue": ctypes.c_bool,
1194
- }
1195
- self.emitter = EmitGemmUniversalInstance3x("_type")
1196
-
1197
- def get_device_workspace_size(self, arguments: GemmArguments3x):
1198
- return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
1199
-
1200
-
1201
- class EmitGemmUniversalInstance3x:
1202
- """Responsible for emitting a CUTLASS 3 template definition"""
1203
-
1204
- def __init__(self, operation_suffix=""):
1205
- self.operation_suffix = operation_suffix
1206
- self.includes = [
1207
- "cutlass/cutlass.h",
1208
- "cute/tensor.hpp",
1209
- "cute/atom/mma_atom.hpp",
1210
- "cutlass/numeric_types.h",
1211
- "cutlass/gemm/collective/collective_builder.hpp",
1212
- "cutlass/gemm/kernel/sm90_tile_scheduler.hpp",
1213
- "cutlass/gemm/kernel/gemm_universal.hpp",
1214
- "cutlass/epilogue/collective/collective_builder.hpp",
1215
- "cutlass/epilogue/collective/default_epilogue.hpp",
1216
- "cutlass/epilogue/thread/linear_combination.h"
1217
- ]
1218
- self.gemm_template_kernel = """
1219
- using namespace cute;
1220
-
1221
- using CollectiveEpilogue =
1222
- typename cutlass::epilogue::collective::CollectiveBuilder<
1223
- ${arch}, ${opcode_class},
1224
- cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
1225
- cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
1226
- cutlass::epilogue::collective::EpilogueTileAuto,
1227
- ${element_accumulator}, ${element_epilogue},
1228
- ${element_c}, ${layout_c}, ${align_c},
1229
- ${element_d}, ${layout_d}, ${align_d},
1230
- ${epilogue_schedule}
1231
- >::CollectiveOp;
1232
-
1233
- using CollectiveMainloop =
1234
- typename cutlass::gemm::collective::CollectiveBuilder<
1235
- ${arch}, ${opcode_class},
1236
- ${element_a}, ${layout_a}, ${align_a},
1237
- ${element_b}, ${layout_b}, ${align_b},
1238
- ${element_accumulator},
1239
- cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
1240
- cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
1241
- ${stage_count_type},
1242
- ${kernel_schedule}
1243
- >::CollectiveOp;
1244
-
1245
- // Gemm operator ${operation_name}
1246
- using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
1247
- Shape<int,int,int,int>,
1248
- CollectiveMainloop,
1249
- CollectiveEpilogue,
1250
- ${tile_scheduler}
1251
- >;
1252
-
1253
- // Define named type
1254
- struct ${operation_name}${operation_suffix} :
1255
- public ${operation_name}_base { };
1256
- """
1257
- self.gemm_template_kernel_visitor = """
1258
- using namespace cute;
1259
-
1260
- ${callback_decl}
1261
-
1262
- using CollectiveEpilogue =
1263
- typename cutlass::epilogue::collective::CollectiveBuilder<
1264
- ${arch}, ${opcode_class},
1265
- cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
1266
- cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
1267
- cutlass::epilogue::collective::EpilogueTileAuto,
1268
- ${element_accumulator}, ${element_epilogue},
1269
- ElementC, StrideC, ${align_c},
1270
- ElementD, StrideD, ${align_d},
1271
- ${epilogue_schedule},
1272
- ${callback_name}
1273
- >::CollectiveOp;
1274
-
1275
- using CollectiveMainloop =
1276
- typename cutlass::gemm::collective::CollectiveBuilder<
1277
- ${arch}, ${opcode_class},
1278
- ${element_a}, ${layout_a}, ${align_a},
1279
- ${element_b}, ${layout_b}, ${align_b},
1280
- ${element_accumulator},
1281
- cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
1282
- cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
1283
- ${stage_count_type},
1284
- ${kernel_schedule}
1285
- >::CollectiveOp;
1286
-
1287
- // Gemm operator ${operation_name}
1288
- using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
1289
- Shape<int,int,int,int>,
1290
- CollectiveMainloop,
1291
- CollectiveEpilogue,
1292
- ${tile_scheduler}
1293
- >;
1294
-
1295
- // Define named type
1296
- struct ${operation_name}${operation_suffix} :
1297
- public ${operation_name}_base { };
1298
- """
1299
-
1300
- self.gemm_template_device = self.gemm_template_kernel + """
1301
-
1302
- // Define device-level operator
1303
- using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>;
1304
- """
1305
-
1306
- def emit(self, operation):
1307
- # Support built-in epilogue functors or user-defined functions
1308
-
1309
- if operation.tile_description.stages is None or operation.tile_description.stages == 0:
1310
- stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>"
1311
- else:
1312
- stage_count_type = "_" + str(operation.tile_description.stages)
1313
-
1314
- if operation.emission_type == EmissionType.Kernel:
1315
- gemm_template = self.gemm_template_kernel
1316
- else:
1317
- gemm_template = self.gemm_template_device
1318
-
1319
- kschedule = KernelScheduleType.ScheduleAuto
1320
- eschedule = EpilogueScheduleType.ScheduleAuto
1321
- tschedule = TileSchedulerType.Default
1322
- if operation.tile_description.kernel_schedule is not None:
1323
- kschedule = operation.tile_description.kernel_schedule
1324
- if operation.tile_description.epilogue_schedule is not None:
1325
- eschedule = operation.tile_description.epilogue_schedule
1326
- if operation.tile_description.tile_scheduler is not None:
1327
- tschedule = operation.tile_description.tile_scheduler
1328
-
1329
- emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape
1330
-
1331
- values = {
1332
- "operation_name": operation.procedural_name(),
1333
- "operation_suffix": self.operation_suffix,
1334
- "element_a": DataTypeTag[operation.A.element],
1335
- "layout_a": LayoutTag[operation.A.layout],
1336
- "element_b": DataTypeTag[operation.B.element],
1337
- "layout_b": LayoutTag[operation.B.layout],
1338
- "element_c": DataTypeTag[operation.C.element],
1339
- "layout_c": LayoutTag[operation.C.layout],
1340
- "element_d": DataTypeTag[operation.epilogue_functor.element_output],
1341
- "layout_d": LayoutTag[operation.C.layout],
1342
- "element_accumulator": DataTypeTag[operation.accumulator_type()],
1343
- "element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue],
1344
- "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
1345
- "arch": "cutlass::arch::Sm%d" % operation.arch,
1346
- "threadblock_shape_m": str(emit_tile_m),
1347
- "threadblock_shape_n": str(emit_tile_n),
1348
- "threadblock_shape_k": str(emit_tile_k),
1349
- "cluster_m": str(operation.tile_description.cluster_shape[0]),
1350
- "cluster_n": str(operation.tile_description.cluster_shape[1]),
1351
- "cluster_k": str(operation.tile_description.cluster_shape[2]),
1352
- "align_a": str(operation.A.alignment),
1353
- "align_b": str(operation.B.alignment),
1354
- "align_c": str(operation.C.alignment),
1355
- "align_d": str(operation.C.alignment),
1356
- "stage_count_type": stage_count_type,
1357
- "kernel_schedule": KernelScheduleTag[kschedule],
1358
- "epilogue_schedule": EpilogueScheduleTag[eschedule],
1359
- "tile_scheduler": TileSchedulerTag[tschedule]
1360
- }
1361
- if hasattr(operation.epilogue_functor, "visitor"):
1362
- callback_name, callback_decl = operation.epilogue_functor.emit(operation)
1363
- values["callback_name"] = callback_name
1364
- values["callback_decl"] = callback_decl
1365
- return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
1366
-
1367
- else:
1368
- values["epilogue_functor"] = operation.epilogue_functor.emit()
1369
- return SubstituteTemplate(gemm_template, values)
1370
-
1371
-
1372
- ###################################################################################################
1373
- # Runtime module for GEMM Grouped
1374
- ###################################################################################################
1375
-
1376
-
1377
- class GemmRTGrouped(GemmRTbase):
1378
- """
1379
- GemmRTGrouped manages the CUTLASS runtime components
1380
- """
1381
-
1382
- KernelTemplate = r"""
1383
- extern "C"
1384
- __global__ void
1385
- ${operation_name}(${operation_name}${operation_suffix}::Params params) {
1386
-
1387
- // Dynamic shared memory base pointer
1388
- extern __shared__ int SharedStorageBase[];
1389
-
1390
- // Declare pointer to dynamic shared memory.
1391
- ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
1392
- reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
1393
-
1394
- ${operation_name}${operation_suffix} op;
1395
-
1396
- op(params, *shared_storage);
1397
- }
1398
- """
1399
-
1400
- HostTemplate = r"""
1401
- extern "C" {
1402
-
1403
- // precompute scheduling information
1404
- char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) {
1405
- char* host_workspace = new char[workspace_bytes];
1406
- ${operation_name}_base::ProblemVisitor::host_precompute(
1407
- args.host_problem_sizes,
1408
- args.problem_count,
1409
- args.threadblock_count,
1410
- (void*)host_workspace
1411
- );
1412
- return host_workspace;
1413
- }
1414
-
1415
- // Get the size of params in bytes
1416
- int ${operation_name}_get_param_size(){
1417
- return sizeof(${operation_name}${operation_suffix}::Params);
1418
- }
1419
-
1420
- // Get the size of dynamic shared memory in bytes
1421
- int ${operation_name}_shared_memory_size() {
1422
- return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
1423
- }
1424
-
1425
- // Get the params as byte array
1426
- char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){
1427
- ${operation_name}_base::Params* params;
1428
- params = new ${operation_name}_base::Params(*argument, workspace, tile_count);
1429
-
1430
- char *bytes = ((char*)(params));
1431
- char *output = new char[sizeof(${operation_name}_base::Params)];
1432
- for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
1433
- output[i] = bytes[i];
1434
-
1435
- return output;
1436
- }
1437
-
1438
- cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
1439
- cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
1440
- return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
1441
- problem_size, tile_size, split_k_slices);
1442
- }
1443
-
1444
- dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
1445
- return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
1446
- }
1447
- }
1448
- """
1449
-
1450
- def __init__(self, operation: "GemmOperation"):
1451
- super(GemmRTGrouped, self).__init__(operation)
1452
- self.extra_funcs = {
1453
- "precompute": None,
1454
- "get_tiled_shape": GemmCoord_,
1455
- "get_grid_shape": dim3_,
1456
- }
1457
- self.emitter = EmitGemmGroupedInstance("_type")
1458
- self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
1459
- self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]
1460
-
1461
- def host_precompute(self, arguments, workspace_bytes):
1462
- self.precompute.argtype = [
1463
- self.argtype[0], ctypes.c_int, ctypes.c_longlong]
1464
- self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes)
1465
-
1466
- problem_info = self.precompute(
1467
- ctypes.byref(arguments.arguments),
1468
- arguments.total_tiles,
1469
- workspace_bytes)
1470
- problem_info_array = bytearray(problem_info.contents)
1471
-
1472
- # copy to device memory
1473
- return todevice(problem_info_array).ptr
1474
-
1475
- def plan(self, arguments):
1476
- return LaunchConfiguration(
1477
- [arguments.total_tiles, 1, 1],
1478
- [self.threads, 1, 1],
1479
- self.shared_memory_capacity,
1480
- )
1481
-
1482
- def get_workspace_size(self, arguments):
1483
- if self.operation.precompute_mode == SchedulerMode.Device:
1484
- return 0
1485
- elif self.operation.precompute_mode == SchedulerMode.Host:
1486
- total_tiles = arguments.total_tiles
1487
- entries_per_block = 1
1488
- return 8 * entries_per_block * total_tiles # three int32_t
1489
-
1490
-
1491
- ################################################################################
1492
- # Runtime module for GEMM and grouped GEMM
1493
- ################################################################################
1494
-
1495
-
1496
- class GemmOperationBase:
1497
- """
1498
- CUTLASS GEMM operation
1499
- """
1500
-
1501
- def __init__(
1502
- self, gemm_kind, arch, tile_description: TileDescription,
1503
- A: TensorDescription, B: TensorDescription, C: TensorDescription,
1504
- epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1,
1505
- api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs):
1506
- self.operation_kind: OperationKind = OperationKind.Gemm
1507
- self.arch: int = arch
1508
- self.tile_description: TileDescription = tile_description
1509
- self.gemm_kind: GemmKind = gemm_kind
1510
-
1511
- self.api = api
1512
- self.prefix = "3x" if self.api == ApiVersion.v3x else ""
1513
- self.emission_type = emission_type
1514
-
1515
- # Optionally swap the TensorDescriptions for operands A and B and transpose their
1516
- # layouts. This is needed to mimic the transpose performed by device::GemmUniversal.
1517
- # The code below uses deep copy to avoid overwritting the original TensorDescription
1518
- self.switched = (self.api != ApiVersion.v3x and
1519
- self.emission_type == EmissionType.Kernel and
1520
- C.layout == LayoutType.ColumnMajor)
1521
-
1522
- self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched)
1523
-
1524
- self.epilogue_functor = epilogue_functor
1525
- self.swizzling_functor = swizzling_functor
1526
-
1527
- if "direct_store" in kwargs:
1528
- self.direct_store = kwargs["direct_store"]
1529
- else:
1530
- self.direct_store = False
1531
-
1532
- @staticmethod
1533
- def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool):
1534
- """
1535
- Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set,
1536
- A and B are swapped, and the layout of A, B, and C are transposed.
1537
-
1538
- :param A: description of operand A
1539
- :type A: TensorDescription
1540
- :param B: description of operand B
1541
- :type B: TensorDescription
1542
- :param C: description of operand C
1543
- :type C: TensorDescription
1544
-
1545
- :return: descriptions of operands A, B, and C
1546
- :rtype: tuple[TileDescription]
1547
- """
1548
- if swap:
1549
- A_out = copy.deepcopy(B)
1550
- B_out = copy.deepcopy(A)
1551
- C_out = copy.deepcopy(C)
1552
- A_out.layout = transpose_layout(A_out.layout)
1553
- B_out.layout = transpose_layout(B_out.layout)
1554
- C_out.layout = transpose_layout(C_out.layout)
1555
- else:
1556
- A_out = copy.deepcopy(A)
1557
- B_out = copy.deepcopy(B)
1558
- C_out = copy.deepcopy(C)
1559
- return A_out, B_out, C_out
1560
-
1561
- def run(self, arguments: GemmArguments) -> cuda.CUresult:
1562
- """
1563
- Configure and launch the cuda kernel with input arguments
1564
- """
1565
- if self.emission_type == EmissionType.Device:
1566
- raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"')
1567
-
1568
- err = self.rt_module.run(
1569
- arguments.host_workspace,
1570
- arguments.device_workspace,
1571
- arguments.launch_config,
1572
- arguments.stream
1573
- )
1574
-
1575
- if err != cuda.CUresult.CUDA_SUCCESS:
1576
- raise RuntimeError("CUDA Error %s" % str(err))
1577
-
1578
- return err
1579
-
1580
- def is_complex(self):
1581
- complex_operators = [
1582
- MathOperation.multiply_add_complex,
1583
- MathOperation.multiply_add_complex_gaussian,
1584
- MathOperation.multiply_add_complex_fast_f32,
1585
- ]
1586
- return self.tile_description.math_instruction.math_operation in complex_operators
1587
-
1588
- def is_planar_complex(self):
1589
- return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
1590
-
1591
- def accumulator_type(self):
1592
- accum = self.tile_description.math_instruction.element_accumulator
1593
-
1594
- if self.is_complex():
1595
- return get_complex_from_real(accum)
1596
-
1597
- return accum
1598
-
1599
- def short_math_name(self):
1600
- if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
1601
- return "g%s" % ShortDataTypeNames[self.accumulator_type()]
1602
- return ShortDataTypeNames[self.accumulator_type()]
1603
-
1604
- def core_name(self):
1605
- """The basic operation kind is prefixed with a letter indicating the accumulation type."""
1606
-
1607
- inst_shape = ""
1608
- inst_operation = ""
1609
- intermediate_type = ""
1610
-
1611
- math_operations_map = {
1612
- MathOperation.xor_popc: "xor",
1613
- }
1614
-
1615
- if (self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or
1616
- self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp):
1617
- math_op = self.tile_description.math_instruction.math_operation
1618
- math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ""
1619
-
1620
- if self.tile_description.math_instruction.instruction_shape is not None:
1621
- if self.api == ApiVersion.v3x and self.arch >= 90:
1622
- inst_shape = "%dx%dx%d" % tuple(
1623
- self.tile_description.math_instruction.instruction_shape)
1624
- else:
1625
- inst_shape = "%d%d%d" % tuple(
1626
- self.tile_description.math_instruction.instruction_shape)
1627
- else:
1628
- inst_shape = "Default"
1629
- inst_shape += math_op_string
1630
-
1631
- if (self.tile_description.math_instruction.element_a != self.A.element and
1632
- self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator):
1633
- intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
1634
-
1635
- return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
1636
-
1637
- def extended_name(self):
1638
- """Append data types if they differ from compute type."""
1639
- if self.is_complex():
1640
- extended_name = "${core_name}"
1641
- else:
1642
- if (self.C.element != self.tile_description.math_instruction.element_accumulator and
1643
- self.A.element != self.tile_description.math_instruction.element_accumulator):
1644
- extended_name = "${element_c}_${core_name}_${element_a}"
1645
- elif (self.C.element == self.tile_description.math_instruction.element_accumulator and
1646
- self.A.element != self.tile_description.math_instruction.element_accumulator):
1647
- extended_name = "${core_name}_${element_a}"
1648
- else:
1649
- extended_name = "${core_name}"
1650
-
1651
- extended_name = SubstituteTemplate(extended_name, {
1652
- "element_a": DataTypeNames[self.A.element],
1653
- "element_c": DataTypeNames[self.C.element],
1654
- "core_name": self.core_name(),
1655
- })
1656
-
1657
- return extended_name
1658
-
1659
- def extended_name_3x(self):
1660
- """Generates a string representing the MMA atom. Assumes accumulator type is C type."""
1661
- extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
1662
- element_a=DataTypeNames[self.A.element],
1663
- element_b=DataTypeNames[self.B.element],
1664
- element_acc=DataTypeNames[self.accumulator_type()],
1665
- element_c=DataTypeNames[self.C.element],
1666
- element_d=DataTypeNames[self.epilogue_functor.element_output],
1667
- core_name=self.core_name())
1668
- return extended_name
1669
-
1670
- def layout_name(self):
1671
- if self.is_complex() or self.is_planar_complex():
1672
- return "%s%s" % (
1673
- ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
1674
- ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
1675
- )
1676
- return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
1677
-
1678
- # Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
1679
- def layout_name_3x(self):
1680
- if self.is_complex() or self.is_planar_complex():
1681
- return "{}{}{}".format(
1682
- ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
1683
- ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
1684
- ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
1685
- else:
1686
- return "{}{}{}".format(
1687
- ShortLayoutTypeNames[self.A.layout],
1688
- ShortLayoutTypeNames[self.B.layout],
1689
- ShortLayoutTypeNames[self.C.layout])
1690
-
1691
- # Generates a short string representing underlying kernel schedule type
1692
- def kernel_schedule_name_3x(self):
1693
- if self.tile_description.kernel_schedule is None:
1694
- return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto]
1695
- else:
1696
- return KernelScheduleSuffixes[self.tile_description.kernel_schedule]
1697
-
1698
- # Generates a short string representing underlying epilogue schedule type
1699
- def epilogue_schedule_name_3x(self):
1700
- if self.tile_description.epilogue_schedule is None:
1701
- return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto]
1702
- else:
1703
- return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule]
1704
-
1705
- def procedural_name(self):
1706
- """The full procedural name indicates architecture, extended name, tile size, and layout."""
1707
- opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
1708
- if self.api == ApiVersion.v3x and self.arch >= 90:
1709
- kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
1710
- return kernel_name_template.format(
1711
- p=self.prefix,
1712
- ar=self.arch,
1713
- op=opcode_class_name,
1714
- ex=self.extended_name_3x(),
1715
- tbm=self.tile_description.threadblock_shape[0],
1716
- tbn=self.tile_description.threadblock_shape[1],
1717
- tbk=self.tile_description.threadblock_shape[2],
1718
- cm=self.tile_description.cluster_shape[0],
1719
- cn=self.tile_description.cluster_shape[1],
1720
- ck=self.tile_description.cluster_shape[2],
1721
- l=self.tile_description.stages,
1722
- s=self.layout_name_3x(),
1723
- al=str(self.A.alignment),
1724
- k=self.kernel_schedule_name_3x(),
1725
- e=self.epilogue_schedule_name_3x()
1726
- )
1727
- else:
1728
- threadblock = self.tile_description.procedural_name_2x()
1729
- return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
1730
- p=self.prefix,
1731
- op=opcode_class_name,
1732
- ex=self.extended_name(),
1733
- tb=threadblock,
1734
- l=self.layout_name(),
1735
- a=str(self.A.alignment)
1736
- )
1737
-
1738
- def configuration_name(self):
1739
- """The full procedural name indicates architecture, extended name, tile size, and layout."""
1740
- return self.procedural_name()
1741
-
1742
-
1743
- class GemmOperationUniversal(GemmOperationBase):
1744
- def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
1745
- epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
1746
- api = api_version(arch, tile_description.math_instruction.opcode_class, A.element)
1747
- super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description,
1748
- A, B, C, epilogue_functor, swizzling_functor,
1749
- api=api, **kwargs, )
1750
- if api == ApiVersion.v3x:
1751
- if swizzling_functor == SwizzlingFunctor.StreamK:
1752
- raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels")
1753
- self.rt_module = GemmRTUniversal3x(self)
1754
- else:
1755
- if swizzling_functor == SwizzlingFunctor.StreamK:
1756
- self.rt_module = GemmRTUniversalStreamK(self)
1757
- else:
1758
- self.rt_module = GemmRTUniversal(self)
1759
- self.argument_type = self.rt_module.argument_type
1760
- self.epilogue_type = self.rt_module.epilogue_type
1761
-
1762
- def device_op(self):
1763
- """
1764
- Returns a new GemmOperationUniversal object that is constructed with emission type
1765
- ``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
1766
- any swappng performed by the kernel-emitted operation is reversed.
1767
-
1768
- :return: operation ready for device-level code emission
1769
- :rtype: GemmUniversalOperation
1770
- """
1771
- A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
1772
- return GemmOperationUniversal(self.arch, self.tile_description, A, B, C,
1773
- self.epilogue_functor, self.swizzling_functor,
1774
- emission_type=EmissionType.Device, direct_store=self.direct_store)
1775
-
1776
-
1777
- class GemmOperationGrouped(GemmOperationBase):
1778
- def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
1779
- epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
1780
- super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
1781
- A, B, C, epilogue_functor, swizzling_functor, **kwargs)
1782
- assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'."
1783
- self.precompute_mode = kwargs["precompute_mode"]
1784
- self.rt_module = GemmRTGrouped(self)
1785
- self.argument_type = self.rt_module.argument_type
1786
- self.epilogue_type = self.rt_module.epilogue_type
1787
-
1788
- def device_op(self):
1789
- """
1790
- Returns a new GemmOperationGrouped object that is constructed with emission type
1791
- ``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
1792
- any swappng performed by the kernel-emitted operation is reversed.
1793
-
1794
- :return: operation ready for device-level code emission
1795
- :rtype: GemmOperationGrouped
1796
- """
1797
- A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
1798
- return GemmOperationGrouped(
1799
- self.arch, self.tile_description, A, B, C, self.epilogue_functor,
1800
- self.swizzling_functor, emission_type=EmissionType.Device,
1801
- direct_store=self.direct_store, precompute_mode=self.precompute_mode, )
1802
-
1803
-
1804
- ###################################################################################################
1805
- #
1806
- # Emits single instances of a CUTLASS device-wide operator
1807
- #
1808
- ###################################################################################################
1809
-
1810
-
1811
- class EmitGemmUniversalInstance:
1812
- """Responsible for emitting a CUTLASS template definition"""
1813
-
1814
- def __init__(
1815
- self,
1816
- operation_suffix="",
1817
- direct_store=False
1818
- ):
1819
- self.operation_suffix = operation_suffix
1820
- self.direct_store = direct_store
1821
- self.includes = [
1822
- "cutlass/cutlass.h",
1823
- "cutlass/gemm_coord.h",
1824
- "cutlass/numeric_types.h",
1825
- "cutlass/arch/arch.h",
1826
- "cutlass/arch/mma.h",
1827
- "cutlass/layout/matrix.h",
1828
- "cutlass/gemm/device/gemm.h",
1829
- "cutlass/gemm/device/gemm_universal_adapter.h",
1830
- "cutlass/gemm/kernel/default_gemm_universal.h",
1831
- ]
1832
- if self.direct_store:
1833
- self.includes.append(
1834
- "cutlass/epilogue/threadblock/default_epilogue_direct_store.h"
1835
- )
1836
- self.gemm_template_kernel = """
1837
- // Gemm operator ${operation_name}
1838
- using ${operation_name}_base =
1839
- typename cutlass::gemm::kernel::DefaultGemmUniversal<
1840
- ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
1841
- ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
1842
- ${element_c}, ${layout_c},
1843
- ${element_accumulator},
1844
- ${opcode_class},
1845
- ${arch},
1846
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1847
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1848
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1849
- ${epilogue_functor},
1850
- ${swizzling_functor},
1851
- ${stages},
1852
- ${math_operation}
1853
- >::GemmKernel;
1854
-
1855
- // Define named type
1856
- struct ${operation_name}${operation_suffix} :
1857
- public ${operation_name}_base { };
1858
- """
1859
-
1860
- self.gemm_template_device = """
1861
- // Gemm operator ${operation_name}
1862
- using DeviceKernel =
1863
- typename cutlass::gemm::device::GemmUniversal<
1864
- // Data type and layout of operand A
1865
- ${element_a}, ${layout_a},
1866
- // Data type and layout of operand B
1867
- ${element_b}, ${layout_b},
1868
- // Data type and layout of operand C
1869
- ${element_c}, ${layout_c},
1870
- // Data type of accumulator
1871
- ${element_accumulator},
1872
- // Class of operation
1873
- ${opcode_class},
1874
- // Compute capability of the target kernel
1875
- ${arch},
1876
- // Threadblock tile shape
1877
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1878
- // Warp tile shape
1879
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1880
- // Instruction shape
1881
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1882
- // Epilogue functor
1883
- ${epilogue_functor},
1884
- // Swizzling function
1885
- ${swizzling_functor},
1886
- // Number of pipeline stages
1887
- ${stages},
1888
- // Alignment of operands A and B
1889
- ${align_a}, ${align_b},
1890
- // Type of math operation
1891
- ${math_operation},
1892
- // Complex transform types of operands A and B
1893
- ${transform_a}, ${transform_b}
1894
- >;
1895
- """
1896
- self.gemm_template_direct_store = """
1897
- // Gemm operator ${operation_name}
1898
- using ${operation_name}_default =
1899
- typename cutlass::gemm::kernel::DefaultGemmUniversal<
1900
- ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
1901
- ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
1902
- ${element_c}, ${layout_c},
1903
- ${element_accumulator},
1904
- ${opcode_class},
1905
- ${arch},
1906
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1907
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1908
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1909
- ${epilogue_functor},
1910
- ${swizzling_functor},
1911
- ${stages},
1912
- ${math_operation}
1913
- >::GemmKernel;
1914
-
1915
- using ${operation_name}_base =
1916
- cutlass::gemm::kernel::GemmUniversal<
1917
- ${operation_name}_default::Mma,
1918
- cutlass::epilogue::threadblock::DefaultEpilogueDirectStore<
1919
- ${operation_name}_default::Epilogue
1920
- >::Epilogue,
1921
- ${operation_name}_default::ThreadblockSwizzle
1922
- >;
1923
-
1924
- // Define named type
1925
- struct ${operation_name}${operation_suffix} :
1926
- public ${operation_name}_base { };
1927
- """
1928
- self.gemm_template_kernel_visitor = """
1929
-
1930
- using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
1931
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1932
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1933
- ${element_c},
1934
- ${align_c},
1935
- ${epilogue_stages} /* epilogue stages */
1936
- >;
1937
-
1938
- ${callback_decl}
1939
-
1940
- // Gemm operator ${operation_name}
1941
- using ${operation_name}_base =
1942
- typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
1943
- ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
1944
- ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
1945
- ${element_c}, ${layout_c}, ${align_c},
1946
- ${element_accumulator},
1947
- ${element_epilogue},
1948
- ${opcode_class},
1949
- ${arch},
1950
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
1951
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
1952
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
1953
- ${callback_name},
1954
- ${swizzling_functor},
1955
- ${stages},
1956
- ${math_operation},
1957
- ${epilogue_stages} /* epilogue stages */
1958
- >::GemmKernel;
1959
-
1960
- // Define named type
1961
- struct ${operation_name}${operation_suffix} :
1962
- public ${operation_name}_base { };
1963
- """
1964
-
1965
- def instance_template(self):
1966
- return """
1967
- ${compile_guard_start}
1968
- manifest.append(new ${gemm_kind}<
1969
- cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
1970
- >("${operation_name}"));
1971
- ${compile_guard_end}
1972
- """
1973
-
1974
- def emit(self, operation):
1975
- threadblock_shape = operation.tile_description.threadblock_shape
1976
- warp_count = operation.tile_description.warp_count
1977
-
1978
- warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
1979
-
1980
- instance_layout_A, instance_layout_B, instance_layout_C = \
1981
- (operation.A.layout, operation.B.layout, operation.C.layout)
1982
-
1983
- if operation.emission_type == EmissionType.Kernel:
1984
- if self.direct_store:
1985
- gemm_template = self.gemm_template_direct_store
1986
- else:
1987
- gemm_template = self.gemm_template_kernel
1988
- else:
1989
- gemm_template = self.gemm_template_device
1990
-
1991
- values = {
1992
- "operation_name": operation.procedural_name(),
1993
- "operation_suffix": self.operation_suffix,
1994
- "element_a": DataTypeTag[operation.A.element],
1995
- "layout_a": LayoutTag[instance_layout_A],
1996
- "element_b": DataTypeTag[operation.B.element],
1997
- "layout_b": LayoutTag[instance_layout_B],
1998
- "element_c": DataTypeTag[operation.C.element],
1999
- "layout_c": LayoutTag[instance_layout_C],
2000
- "element_accumulator": DataTypeTag[operation.accumulator_type()],
2001
- "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
2002
- "arch": "cutlass::arch::Sm%d" % operation.arch,
2003
- "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
2004
- "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
2005
- "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
2006
- "warp_shape_m": str(warp_shape[0]),
2007
- "warp_shape_n": str(warp_shape[1]),
2008
- "warp_shape_k": str(warp_shape[2]),
2009
- "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
2010
- "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
2011
- "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
2012
- "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
2013
- "stages": str(operation.tile_description.stages),
2014
- "align_a": str(operation.A.alignment),
2015
- "align_b": str(operation.B.alignment),
2016
- "transform_a": ComplexTransformTag[operation.A.complex_transform],
2017
- "transform_b": ComplexTransformTag[operation.B.complex_transform],
2018
- "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation],
2019
- }
2020
-
2021
- if hasattr(operation.epilogue_functor, "visitor"):
2022
- self.includes += [
2023
- "cutlass/epilogue/threadblock/fusion/visitors.hpp",
2024
- "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
2025
- ]
2026
- callback_name, callback_decl = operation.epilogue_functor.emit(operation)
2027
- values["callback_name"] = callback_name
2028
- values["callback_decl"] = callback_decl
2029
- values["align_c"] = str(operation.C.alignment)
2030
- values["element_epilogue"] = DataTypeTag[operation.epilogue_functor.element_epilogue]
2031
- if hasattr(operation.epilogue_functor, "epilogue_stages"):
2032
- epilogue_stages = operation.epilogue_functor.epilogue_stages
2033
- else:
2034
- epilogue_stages = 1
2035
- values["epilogue_stages"] = str(epilogue_stages)
2036
- return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
2037
- else:
2038
- values["epilogue_functor"] = operation.epilogue_functor.emit()
2039
- return SubstituteTemplate(gemm_template, values)
2040
-
2041
-
2042
- class EmitGemmGroupedInstance:
2043
- """Responsible for emitting a CUTLASS template definition"""
2044
-
2045
- def __init__(self, operation_suffix=""):
2046
- self.operation_suffix = operation_suffix
2047
- self.includes = [
2048
- "cutlass/cutlass.h",
2049
- "cutlass/numeric_types.h",
2050
- "cutlass/arch/arch.h",
2051
- "cutlass/arch/mma.h",
2052
- "cutlass/layout/matrix.h",
2053
- "cutlass/gemm/kernel/gemm_grouped.h",
2054
- "cutlass/gemm/kernel/default_gemm_grouped.h",
2055
- ]
2056
- self.gemm_template_kernel = """
2057
- // Gemm operator ${operation_name}
2058
- using ${operation_name}_base =
2059
- typename cutlass::gemm::kernel::DefaultGemmGrouped<
2060
- ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
2061
- ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
2062
- ${element_c}, ${layout_c},
2063
- ${element_accumulator},
2064
- ${opcode_class},
2065
- ${arch},
2066
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
2067
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
2068
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
2069
- ${epilogue_functor},
2070
- ${swizzling_functor},
2071
- ${stages},
2072
- ${precompute_mode},
2073
- ${math_operation}
2074
- >::GemmKernel;
2075
-
2076
- // Define named type
2077
- struct ${operation_name}${operation_suffix} :
2078
- public ${operation_name}_base { };
2079
- """
2080
- self.gemm_template_device = (
2081
- self.gemm_template_kernel
2082
- + """
2083
- using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>;
2084
- """
2085
- )
2086
-
2087
- def instance_template(self):
2088
- return """
2089
- ${compile_guard_start}
2090
- manifest.append(new ${gemm_kind}<
2091
- cutlass::gemm::device::GemmGrouped<${operation_name}>
2092
- >("${operation_name}"));
2093
- ${compile_guard_end}
2094
- """
2095
-
2096
- def emit(self, operation):
2097
- threadblock_shape = operation.tile_description.threadblock_shape
2098
- warp_count = operation.tile_description.warp_count
2099
-
2100
- warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
2101
-
2102
- instance_layout_A, instance_layout_B, instance_layout_C = \
2103
- (operation.A.layout, operation.B.layout, operation.C.layout)
2104
-
2105
- # Support built-in epilogue functors or user-defined functions
2106
- epilogue_functor = operation.epilogue_functor.emit()
2107
-
2108
- values = {
2109
- "operation_name": operation.procedural_name(),
2110
- "operation_suffix": self.operation_suffix,
2111
- "element_a": DataTypeTag[operation.A.element],
2112
- "layout_a": LayoutTag[instance_layout_A],
2113
- "element_b": DataTypeTag[operation.B.element],
2114
- "layout_b": LayoutTag[instance_layout_B],
2115
- "element_c": DataTypeTag[operation.C.element],
2116
- "layout_c": LayoutTag[instance_layout_C],
2117
- "element_accumulator": DataTypeTag[operation.accumulator_type()],
2118
- "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
2119
- "arch": "cutlass::arch::Sm%d" % operation.arch,
2120
- "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
2121
- "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
2122
- "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
2123
- "warp_shape_m": str(warp_shape[0]),
2124
- "warp_shape_n": str(warp_shape[1]),
2125
- "warp_shape_k": str(warp_shape[2]),
2126
- "instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
2127
- "instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
2128
- "instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
2129
- "epilogue_functor": epilogue_functor,
2130
- "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
2131
- "stages": str(operation.tile_description.stages),
2132
- "align_a": str(operation.A.alignment),
2133
- "align_b": str(operation.B.alignment),
2134
- "transform_a": ComplexTransformTag[operation.A.complex_transform],
2135
- "transform_b": ComplexTransformTag[operation.B.complex_transform],
2136
- "precompute_mode": SchedulerModeTag[operation.precompute_mode],
2137
- "math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation],
2138
- }
2139
-
2140
- if operation.emission_type == EmissionType.Kernel:
2141
- gemm_template = self.gemm_template_kernel
2142
- else:
2143
- gemm_template = self.gemm_template_device
2144
-
2145
- return SubstituteTemplate(gemm_template, values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py DELETED
@@ -1,509 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Common data types and string names/tags for them
35
- """
36
-
37
- import enum
38
-
39
- from cutlass_library import (
40
- ComplexTransform,
41
- DataType,
42
- DataTypeSize,
43
- EpilogueScheduleType,
44
- KernelScheduleSuffixes,
45
- KernelScheduleType,
46
- MathOperation,
47
- OpcodeClass,
48
- TileSchedulerType
49
- )
50
-
51
-
52
- # The following block implements enum.auto() for Python 3.5 variants that don't include it such
53
- # as the default 3.5.2 on Ubuntu 16.04.
54
- #
55
- # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
56
-
57
- try:
58
- from enum import auto as enum_auto
59
- except ImportError:
60
- __cutlass_library_auto_enum = 0
61
-
62
- def enum_auto() -> int:
63
- global __cutlass_library_auto_enum
64
- i = __cutlass_library_auto_enum
65
- __cutlass_library_auto_enum += 1
66
- return i
67
-
68
-
69
- class DataTypeSizeBytes:
70
- """
71
- Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
72
- data type key is less than a full byte or a non-integer number of bytes.
73
- """
74
-
75
- @staticmethod
76
- def __class_getitem__(datatype):
77
- """
78
- Returns the number of bytes in size the data type is. Raises an exception if the data type
79
- is either less than a full byte or a non-integer number of bytes in size.
80
-
81
- :param datatype: data type to query
82
-
83
- :return: number of bytes the data type occupies
84
- :rtype: int
85
- """
86
- bits = DataTypeSize[datatype]
87
- if bits < 8:
88
- raise Exception(
89
- f"Data type {datatype} is less than one byte in size."
90
- )
91
- elif bits % 8 != 0:
92
- raise Exception(
93
- f"Data type datatype is not an integer number of bytes."
94
- )
95
- return bits // 8
96
-
97
-
98
- class SchedulerMode(enum.Enum):
99
- Device = enum_auto()
100
- Host = enum_auto()
101
-
102
-
103
- SchedulerModeTag = {
104
- SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
105
- SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
106
- }
107
-
108
-
109
- ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
110
-
111
-
112
- class FunctionalOp(enum.Enum):
113
- AtomicAdd = enum_auto()
114
- AtomicMaximum = enum_auto()
115
- Divides = enum_auto()
116
- Maximum = enum_auto()
117
- Minimum = enum_auto()
118
- Minus = enum_auto()
119
- Multiplies = enum_auto()
120
- MultiplyAdd = enum_auto()
121
- Plus = enum_auto()
122
- Exp = enum_auto()
123
-
124
-
125
- FunctionalOpTag = {
126
- FunctionalOp.AtomicAdd: "cutlass::atomic_add",
127
- FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
128
- FunctionalOp.Divides: "cutlass::divides",
129
- FunctionalOp.Maximum: "cutlass::maximum",
130
- FunctionalOp.Minimum: "cutlass::minimum",
131
- FunctionalOp.Minus: "cutlass::minus",
132
- FunctionalOp.Multiplies: "cutlass::multiplies",
133
- FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
134
- FunctionalOp.Plus: "cutlass::plus",
135
- FunctionalOp.Exp: "cutlass::fast_exp_op",
136
- }
137
-
138
-
139
- class ActivationOp(enum.Enum):
140
- DGelu = enum_auto()
141
- Gelu = enum_auto()
142
- GeluTaylor = enum_auto()
143
- HardSwish = enum_auto()
144
- Identity = enum_auto()
145
- LeakyReLU = enum_auto()
146
- ReLU = enum_auto()
147
- Sigmoid = enum_auto()
148
- SiLU = enum_auto()
149
- Tanh = enum_auto()
150
-
151
-
152
- ActivationOpTag = {
153
- ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
154
- ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
155
- ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
156
- ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
157
- ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
158
- ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
159
- ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
160
- ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
161
- ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
162
- ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
163
- }
164
-
165
-
166
- def op_tag(op) -> str:
167
- """
168
- Dispatches `op` to the appropriate *Tag dictionary depending on whether
169
- `op` is an ActivationOp or FunctionalOp. This is useful for cases in which
170
- either type can be used.
171
-
172
- :param op: operation to emit a tag for
173
- :type op: ActivationOp | FunctionalOp
174
-
175
- :return: tag corresponding to op
176
- :rtype: str
177
- """
178
- if isinstance(op, ActivationOp):
179
- return ActivationOpTag[op]
180
- elif isinstance(op, FunctionalOp):
181
- return FunctionalOpTag[op]
182
- else:
183
- raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
184
-
185
-
186
- class FloatRoundStyle(enum.Enum):
187
- ToNearest = enum_auto()
188
- ToNearestSatfinite = enum_auto()
189
- Indeterminate = enum_auto()
190
- TowardZero = enum_auto()
191
- TowardInfinity = enum_auto()
192
- TowardNegInfinity = enum_auto()
193
- HalfUlpTruncDntz = enum_auto()
194
- HalfUlpTruncate = enum_auto()
195
-
196
-
197
- FloatRoundStyleTag = {
198
- FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
199
- FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
200
- FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
201
- FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
202
- FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
203
- FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
204
- FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
205
- FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
206
- }
207
-
208
-
209
- class MathInstruction:
210
- """
211
- Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel
212
- """
213
-
214
- def __init__(
215
- self,
216
- instruction_shape,
217
- element_a,
218
- element_b,
219
- element_accumulator,
220
- opcode_class=OpcodeClass.Simt,
221
- math_operation=MathOperation.multiply_add,
222
- ):
223
- """
224
- :param instruction_shape: size of the [M, N, K] dimensions of the instruction
225
- :type instruction_shape: list or tuple
226
- :param element_a: data type of operand A
227
- :param element_b: data type of operand B
228
- :param element_accumulator: data type used in accumulation
229
- :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
230
- :type opcode_class: cutlass_library.library.OpcodeClass
231
- :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
232
- :type math_operation: MathOperation
233
- """
234
- self.instruction_shape = instruction_shape
235
- self.element_a = element_a
236
- self.element_b = element_b
237
- self.element_accumulator = element_accumulator
238
- self.opcode_class = opcode_class
239
- self.math_operation = math_operation
240
-
241
-
242
- def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
243
- blackwell_threadblock_shape = tile_description.threadblock_shape
244
- is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
245
- if cluster_shape[0] > 0:
246
- blackwell_threadblock_shape = [
247
- tile_description.threadblock_shape[0] // cluster_shape[0],
248
- tile_description.threadblock_shape[1] // cluster_shape[1],
249
- tile_description.threadblock_shape[2] // cluster_shape[2]
250
- ]
251
- if is_2sm:
252
- blackwell_threadblock_shape[0] *= 2
253
- else:
254
- blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape
255
- return blackwell_threadblock_shape, is_2sm
256
-
257
-
258
- class TileDescription:
259
- """
260
- Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
261
- stage count, and math instruction specification
262
- """
263
-
264
- def __init__(
265
- self,
266
- threadblock_shape,
267
- stages,
268
- warp_count,
269
- math_instruction,
270
- cluster_shape=[1, 1, 1],
271
- kernel_schedule: KernelScheduleType = None,
272
- epilogue_schedule: EpilogueScheduleType = None,
273
- tile_scheduler: TileSchedulerType = None
274
- ):
275
- """
276
- :param threadblock_shape: shape of a threadblock tyle
277
- :type threadblock_shape: list or tuple
278
- :param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
279
- number of stages that can be supported for an operation on a given architecture will be computed at a later time
280
- :type stages: int or None
281
- :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
282
- :type warp_count: list, tuple, or None
283
- :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
284
- :type math_instruction: MathInstruction
285
- :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
286
- :param kernel_schedule: type of kernel schedule to use (only available for SM90+)
287
- :type kernel_schedule: cutlass_library.KernelScheduleType
288
- :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
289
- :type epilogue_schedule: cutlass_library.EpilogueScheduleType
290
- :param tile_scheduler: type of tile scheduler to use (only available for SM90+)
291
- :type tile_scheduler: cutlass_library.TileSchedulerType
292
- """
293
- if ((kernel_schedule is None and epilogue_schedule is not None) or
294
- (kernel_schedule is not None and epilogue_schedule is None)):
295
- raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")
296
-
297
- self.threadblock_shape = threadblock_shape
298
- self.cluster_shape = cluster_shape
299
- self.kernel_schedule = kernel_schedule
300
- self.epilogue_schedule = epilogue_schedule
301
- self.tile_scheduler = tile_scheduler
302
- self.stages = stages
303
-
304
- self.math_instruction = math_instruction
305
- self.instruction_shape = math_instruction.instruction_shape
306
-
307
- # Number of warps along x, y, z directions
308
- self.warp_count = warp_count
309
-
310
- self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)
311
-
312
- def clone_and_update(self, td: dict):
313
- attrs = {
314
- "cluster_shape": None,
315
- "threadblock_shape": None,
316
- "warp_count": None,
317
- "stages": None,
318
- "instruction_shape": None,
319
- "kernel_schedule": None,
320
- "epilogue_schedule": None,
321
- "tile_scheduler": None
322
- }
323
- for key in attrs.keys():
324
- if key in td.keys():
325
- attrs[key] = td[key]
326
- else:
327
- attrs[key] = getattr(self, key)
328
-
329
- attrs["math_instruction"] = MathInstruction(
330
- attrs["instruction_shape"],
331
- self.math_instruction.element_a,
332
- self.math_instruction.element_b,
333
- self.math_instruction.element_accumulator,
334
- self.math_instruction.opcode_class,
335
- self.math_instruction.math_operation
336
- )
337
-
338
- # Remove the instruction shape
339
- del attrs["instruction_shape"]
340
-
341
- return TileDescription(**attrs)
342
-
343
- @property
344
- def num_threads(self):
345
- """
346
- Returns the number of threads in the threadblock
347
-
348
- :return: number of threads in the threadblock
349
- :rtype: int or None (if warp count is None)
350
- """
351
- if self.warp_count is not None:
352
- threads = 32
353
- for cnt in self.warp_count:
354
- threads *= cnt
355
- return threads
356
- return None
357
-
358
- def procedural_name(self):
359
- """
360
- Returns a name identifying the tile description
361
-
362
- :return: name identifying the tile description
363
- :rtype: int
364
- """
365
- emit_stages = 0 if self.stages is None else self.stages
366
- name = "%dx%dx%d_%dx%d_%dx%d" % (
367
- self.cluster_shape[0],
368
- self.cluster_shape[1],
369
- self.cluster_shape[2],
370
- self.threadblock_shape[0],
371
- self.threadblock_shape[1],
372
- self.threadblock_shape[2],
373
- emit_stages
374
- )
375
-
376
- return name
377
-
378
- def procedural_name_2x(self):
379
- """
380
- Returns a name identifying the tile description
381
-
382
- :return: name identifying the tile description
383
- :rtype: int
384
- """
385
- return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
386
-
387
- def __str__(self):
388
- """
389
- Returns a string with containing each of the tile description's values
390
-
391
- :return: contents of tile description
392
- :rtype: str
393
- """
394
- if self.kernel_schedule is not None:
395
- kschedule = self.kernel_schedule
396
- else:
397
- kschedule = KernelScheduleType.ScheduleAuto
398
-
399
- if self.epilogue_schedule is not None:
400
- eschedule = self.epilogue_schedule
401
- else:
402
- eschedule = EpilogueScheduleType.ScheduleAuto
403
-
404
- if self.tile_scheduler is not None:
405
- tschedule = self.tile_scheduler.name
406
- else:
407
- tschedule = "None"
408
- return f"""
409
- {{
410
- ClusterShape: {self.cluster_shape}
411
- ThreadblockShape: {self.threadblock_shape}
412
- WarpCount: {self.warp_count}
413
- Stages: {self.stages if self.stages is not None else 'Auto'}
414
- InstructionShape: {self.math_instruction.instruction_shape}
415
- Kernel schedule: {kschedule.name}
416
- Epilogue schedule: {kschedule.name}
417
- TileScheduler: {tschedule}
418
- }}"""
419
-
420
-
421
- class TensorDescription:
422
- def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
423
- self.element = element
424
- self.layout = layout
425
- if element != DataType.void:
426
- self.alignment = min(128 // DataTypeSize[self.element], alignment)
427
- else:
428
- self.alignment = alignment
429
- self.complex_transform = complex_transform
430
-
431
-
432
- def CalculateSmemUsagePerStage(operation):
433
- """
434
- Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
435
-
436
- :param op: operation for which the maximum stages should be computed. If stages are
437
- set via the `op.tile_description.stages` parameter, this setting is ignored
438
- in the present calculation
439
- :type op: cutlass_cppgen.backend.Operation
440
-
441
- :return: number of bytes of shared memory consumed by a single stage
442
- :rtype: int
443
- """
444
- m, n, k = operation.tile_description.threadblock_shape
445
-
446
- if operation.operation_kind == OperationKind.Gemm:
447
- stage_barrier_bytes = 32
448
- return (
449
- (DataTypeSize[operation.A.element] * m * k // 8)
450
- + (DataTypeSize[operation.B.element] * k * n // 8)
451
- + stage_barrier_bytes
452
- )
453
- else:
454
- raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))
455
-
456
-
457
- def CalculateSmemUsage(operation):
458
- """
459
- Returns the amount of shared memory in bytes consumed by a kernel.
460
-
461
- :param op: operation for which the maximum stages should be computed. If stages are
462
- set via the `op.tile_description.stages` parameter, this setting is ignored
463
- in the present calculation
464
- :type op: cutlass_cppgen.backend.Operation
465
-
466
- :return: int
467
- """
468
- return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)
469
-
470
-
471
- class ApiVersion(enum.Enum):
472
- """
473
- Differentiate between CUTLASS 2.x and 3.x API versions
474
- """
475
-
476
- v2x = enum_auto()
477
- v3x = enum_auto()
478
-
479
-
480
- def api_version(arch, opclass, dtype):
481
- """
482
- Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
483
- or 3.x for code emission.
484
-
485
- :param arch: compute capability of device on which to run
486
- :type arch: int
487
- :param opclass: class of the operation being performed
488
- :type opclass: cutlass_library.OpcodeClass
489
- :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
490
- :type dtype: cutlass_library.DataType
491
-
492
- :return: API version to be used in code emission
493
- :rtype: ApiVersion
494
- """
495
- if (arch in [90, 100, 101, 103] and
496
- opclass == OpcodeClass.TensorOp and
497
- (dtype != DataType.f64)):
498
- return ApiVersion.v3x
499
- else:
500
- return ApiVersion.v2x
501
-
502
-
503
- class EmissionType(enum.Enum):
504
- """
505
- Tags for whether to emit a kernel- or device-level operation
506
- """
507
-
508
- Kernel = enum_auto()
509
- Device = enum_auto()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py DELETED
@@ -1,121 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import numpy as np
34
-
35
- import cutlass_cppgen
36
- from cutlass_cppgen.utils.datatypes import is_numpy_tensor
37
- from cutlass_cppgen.utils.lazy_import import lazy_import
38
-
39
- if cutlass_cppgen.use_rmm:
40
- import rmm
41
- else:
42
- cudart = lazy_import("cuda.cudart")
43
-
44
-
45
- class PoolMemoryManager:
46
- def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
47
- self.pool = rmm.mr.PoolMemoryResource(
48
- rmm.mr.CudaMemoryResource(),
49
- initial_pool_size=init_pool_size,
50
- maximum_pool_size=max_pool_size
51
- )
52
- self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
53
- rmm.mr.set_current_device_resource(self.mr)
54
-
55
- def pool_size(self):
56
- return self.pool.pool_size()
57
-
58
-
59
- class DevicePtrWrapper:
60
- """
61
- Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
62
- (at least in terms of the interface used by the CUTLASS Python interface)
63
- """
64
- def __init__(self, dev_ptr):
65
- self.dev_ptr = dev_ptr
66
-
67
- @property
68
- def ptr(self):
69
- return self.dev_ptr
70
-
71
-
72
- def _todevice(host_data):
73
- """
74
- Helper for transferring host data to device memory
75
- """
76
- if cutlass_cppgen.use_rmm:
77
- return rmm.DeviceBuffer.to_device(host_data.tobytes())
78
- else:
79
- nbytes = len(host_data.tobytes())
80
- dev_ptr_wrapper = device_mem_alloc(nbytes)
81
- err, = cudart.cudaMemcpy(
82
- dev_ptr_wrapper.ptr,
83
- host_data.__array_interface__['data'][0],
84
- nbytes,
85
- cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
86
- )
87
- if err != cudart.cudaError_t.cudaSuccess:
88
- raise Exception(f"cudaMemcpy failed with error {err}")
89
- return dev_ptr_wrapper
90
-
91
-
92
- def todevice(host_data, dtype=np.float32):
93
- """
94
- Pass the host_data to device memory
95
- """
96
- if isinstance(host_data, list):
97
- return _todevice(np.array(host_data, dtype=dtype))
98
- elif is_numpy_tensor(host_data):
99
- return _todevice(host_data)
100
-
101
-
102
- def device_mem_alloc(size):
103
- if cutlass_cppgen.use_rmm:
104
- return rmm.DeviceBuffer(size=size)
105
- else:
106
- err, ptr = cudart.cudaMalloc(size)
107
- if err != cudart.cudaError_t.cudaSuccess:
108
- raise Exception(f"cudaMalloc failed with error {err}")
109
- return DevicePtrWrapper(ptr)
110
-
111
-
112
- def align_size(size, alignment=256):
113
- return ((size + alignment - 1) // alignment) * alignment
114
-
115
-
116
- def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
117
- if cutlass_cppgen.use_rmm:
118
- memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
119
- return memory_pool
120
- else:
121
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py DELETED
@@ -1,140 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import ctypes
34
- from cutlass_cppgen.utils.lazy_import import lazy_import
35
- cuda = lazy_import("cuda.cuda")
36
-
37
- from cutlass_cppgen.backend.utils.device import device_cc
38
-
39
- _supports_cluster_launch = None
40
-
41
-
42
- def supports_cluster_launch():
43
- from cuda import __version__
44
- _version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
45
- global _supports_cluster_launch
46
- if _supports_cluster_launch is None:
47
- major, minor = _version_splits[0], _version_splits[1]
48
- _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
49
- return _supports_cluster_launch
50
-
51
-
52
- class LaunchConfiguration:
53
- def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0):
54
- self.grid = grid
55
- self.block = block
56
- self.shared_memory_capacity = smem
57
-
58
-
59
- class ExecutableOperation:
60
- def __init__(self, operation):
61
- self.operation = operation
62
- self.module = None
63
- self.kernel = None
64
-
65
- def name(self):
66
- return self.operation.procedural_name()
67
-
68
- def emit(self):
69
- return ""
70
-
71
- def can_implement(self, configuration, arguments):
72
- raise NotImplementedError()
73
-
74
- def get_host_workspace_size(self, arguments):
75
- raise NotImplementedError()
76
-
77
- def get_device_workspace_size(self, arguments):
78
- raise NotImplementedError()
79
-
80
- def plan(self, arguments):
81
- raise NotImplementedError()
82
-
83
- def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
84
- raise NotImplementedError()
85
-
86
- def run_with_clusters(self, launch_config, kernel_params, stream=None):
87
- if not stream:
88
- stream = cuda.CUstream(0)
89
- if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
90
- attr = cuda.CUlaunchAttribute()
91
- attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
92
- attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
93
- attrs = [attr]
94
-
95
- # Allow for non-portable cluster sizes
96
- err, = cuda.cuFuncSetAttribute(
97
- self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
98
- if err != cuda.CUresult.CUDA_SUCCESS:
99
- return err
100
- else:
101
- attrs = []
102
-
103
- config = cuda.CUlaunchConfig()
104
- config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
105
- config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
106
- config.blockDimZ = launch_config.block[2]
107
- config.sharedMemBytes = launch_config.shared_memory_capacity
108
- config.hStream = stream
109
- config.attrs = attrs
110
- config.numAttrs = len(attrs)
111
-
112
- err, = cuda.cuLaunchKernelEx(
113
- config, f=self.kernel, kernelParams=kernel_params, extra=0)
114
- return err
115
-
116
- def run_without_clusters(self, launch_config, kernel_params, stream=None):
117
- if not stream:
118
- stream = cuda.CUstream(0)
119
- err, = cuda.cuLaunchKernel(
120
- self.kernel,
121
- launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
122
- launch_config.block[0], launch_config.block[1], launch_config.block[2],
123
- launch_config.shared_memory_capacity,
124
- stream,
125
- kernel_params,
126
- 0)
127
-
128
- return err
129
-
130
- def run(self, host_workspace, device_workspace, launch_config, stream=None):
131
- if not stream:
132
- stream = cuda.CUstream(0)
133
- cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
134
- packed = (ctypes.c_void_p * 1)()
135
- packed[0] = ctypes.addressof(cArg)
136
-
137
- if supports_cluster_launch():
138
- return self.run_with_clusters(launch_config, packed, stream)
139
- else:
140
- return self.run_without_clusters(launch_config, packed, stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py DELETED
@@ -1,455 +0,0 @@
1
- ################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- ################################################################################
32
- from __future__ import annotations
33
-
34
- import ctypes
35
- from typing import Union
36
-
37
- from cutlass_cppgen.utils.lazy_import import lazy_import
38
- cuda = lazy_import("cuda.cuda")
39
- cudart = lazy_import("cuda.cudart")
40
- import numpy as np
41
-
42
- from cutlass_library import (
43
- DataTypeNames,
44
- DataTypeSize,
45
- DataTypeTag,
46
- LayoutType,
47
- SubstituteTemplate
48
- )
49
-
50
- import cutlass_cppgen
51
- from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
52
- from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
53
- from cutlass_cppgen.backend.library import TensorDescription
54
- from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
55
- from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
56
- from cutlass_cppgen.shape import MatrixCoord
57
- from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
58
-
59
-
60
- class ReductionOperation:
61
- pass
62
-
63
-
64
- class ReductionArguments:
65
- """
66
- Arguments of reduction
67
- """
68
-
69
- def __init__(
70
- self,
71
- operation: ReductionOperation,
72
- problem_size: "list[int]",
73
- partitions: int,
74
- workspace: cuda.CUdeviceptr,
75
- destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
76
- source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
77
- **kwargs,
78
- ) -> None:
79
- # tensor_C can be interpreted as the bias with bias=True in keyword args
80
- if "bias" in kwargs.keys():
81
- self.bias = kwargs["bias"]
82
- else:
83
- # by default, tensor_C is not bias
84
- self.bias = False
85
- if "stream" in kwargs.keys():
86
- self.stream = kwargs["stream"]
87
- else:
88
- self.stream = cuda.CUstream(0)
89
-
90
- self.operation = operation
91
- self.ptr_workspace = workspace
92
-
93
- # number of split-k partitions
94
- self.partitions = partitions
95
-
96
- if is_numpy_tensor(destination):
97
- self.host_D = destination
98
- self.destination_buffer = NumpyFrontend.argument(destination, True)
99
- self.source_buffer = NumpyFrontend.argument(source, False)
100
- self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
101
- self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
102
- elif is_torch_tensor(destination):
103
- self.ptr_destination = TorchFrontend.argument(destination)
104
- self.ptr_source = TorchFrontend.argument(source)
105
- elif isinstance(destination, cuda.CUdeviceptr):
106
- self.ptr_destination = destination
107
- self.ptr_source = source
108
- else:
109
- raise TypeError("unknown Type")
110
-
111
- self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])
112
-
113
- self.partition_stride = (
114
- problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
115
- )
116
-
117
- if "output_op" in kwargs.keys():
118
- self.output_op = kwargs["output_op"]
119
- else:
120
- self.output_op = self.operation.epilogue_type(1.0, 0.0)
121
-
122
- self.get_arguments()
123
-
124
- @staticmethod
125
- def get_tensor_ref(
126
- extent: "tuple[int]",
127
- device_ptr: cuda.CUdeviceptr,
128
- layout: LayoutType,
129
- ):
130
- if layout == LayoutType.RowMajor:
131
- return TensorRef2D_(int(device_ptr), extent[1])
132
- else:
133
- raise ValueError(f"Unknown layout type {layout}")
134
-
135
- def get_arguments(self):
136
- ref_workspace = ReductionArguments.get_tensor_ref(
137
- extent=[
138
- self.problem_size.row,
139
- self.problem_size.column,
140
- ],
141
- device_ptr=self.ptr_workspace,
142
- layout=LayoutType.RowMajor,
143
- )
144
- if self.bias:
145
- ref_source = ReductionArguments.get_tensor_ref(
146
- extent=[0, 0],
147
- device_ptr=self.ptr_source,
148
- layout=LayoutType.RowMajor,
149
- )
150
- else:
151
- ref_source = ReductionArguments.get_tensor_ref(
152
- extent=[
153
- self.problem_size.row,
154
- self.problem_size.column,
155
- ],
156
- device_ptr=self.ptr_source,
157
- layout=LayoutType.RowMajor,
158
- )
159
-
160
- ref_destination = ReductionArguments.get_tensor_ref(
161
- extent=[
162
- self.problem_size.row,
163
- self.problem_size.column,
164
- ],
165
- device_ptr=self.ptr_destination,
166
- layout=LayoutType.RowMajor,
167
- )
168
-
169
- self.c_arguments = self.operation.argument_type(
170
- self.problem_size,
171
- self.partitions,
172
- self.partition_stride,
173
- ref_workspace,
174
- ref_destination,
175
- ref_source,
176
- self.output_op,
177
- )
178
-
179
- params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
180
- self.host_workspace = bytearray(params_.contents)
181
-
182
- def sync(self):
183
- (err,) = cudart.cudaDeviceSynchronize()
184
- if err != cuda.CUresult.CUDA_SUCCESS:
185
- raise RuntimeError(f"CUDA Error {str(err)}")
186
-
187
- if hasattr(self, "host_D"):
188
- (err,) = cuda.cuMemcpyDtoH(
189
- self.host_D,
190
- self.ptr_destination,
191
- self.host_D.size * self.host_D.itemsize,
192
- )
193
- if err != cuda.CUresult.CUDA_SUCCESS:
194
- raise RuntimeError("CUDA Error %s" % str(err))
195
-
196
- self.free()
197
-
198
- def free(self):
199
- """
200
- Frees allocated device-side memory
201
- """
202
- # Free any device memory allocated manually
203
- if not cutlass_cppgen.use_rmm:
204
- for attr in ["destination_buffer", "source_buffer"]:
205
- if hasattr(self, attr):
206
- buf = getattr(self, attr)
207
- if isinstance(buf, DevicePtrWrapper):
208
- err, = cudart.cudaFree(buf.ptr)
209
- if err != cudart.cudaError_t.cudaSuccess:
210
- raise RuntimeError(f"cudaFree failed with error {err}")
211
- del buf
212
-
213
-
214
- class ReductionRT(ExecutableOperation):
215
- """
216
- ReductionRT manages the CUTLASS runtime components for reduction
217
- """
218
-
219
- KernelTemplate = r"""
220
- extern "C"
221
- __global__ void
222
- ${operation_name}(${operation_name}${operation_suffix}::Params params) {
223
-
224
- // Dynamic shared memory base pointer
225
- extern __shared__ int SharedStorageBase[];
226
-
227
- // Declare pointer to dynamic shared memory.
228
- ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
229
- reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
230
-
231
- ${operation_name}${operation_suffix} op;
232
-
233
- op(params, *shared_storage);
234
- }
235
- """
236
- HostTemplate = r"""
237
- extern "C" {
238
- // Get the size of params in bytes
239
- int ${operation_name}_get_param_size(){
240
- return sizeof(${operation_name}${operation_suffix}::Params);
241
- }
242
-
243
- // Get the size of dynamic shared memory in bytes
244
- int ${operation_name}_shared_memory_size() {
245
- return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
246
- }
247
-
248
- // Get the params as byte array
249
- char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
250
- char *bytes = ((char*)(params));
251
- char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
252
- for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
253
- output[i] = bytes[i];
254
-
255
- return output;
256
- }
257
- }
258
- """
259
-
260
- def __init__(self, operation: ReductionOperation):
261
- super().__init__(operation)
262
-
263
- self.operation: ReductionOperation = operation
264
- self.emitter = EmitReductionInstance("_type")
265
-
266
- self.elements_per_access = self.operation.count
267
- (
268
- self.argument_type,
269
- self.epilogue_type,
270
- ) = get_reduction_params(operation.epilogue_functor)
271
- self.argtype = [ctypes.POINTER(self.argument_type)]
272
-
273
- def emit(self):
274
- return self.emitter.emit(self.operation)
275
-
276
- def plan(self, arguments: ReductionArguments):
277
- block_shape = [
278
- self.operation.shape.column // self.elements_per_access,
279
- self.operation.shape.row,
280
- 1,
281
- ]
282
- grid_shape = [
283
- (arguments.problem_size.row + self.operation.shape.row - 1)
284
- // self.operation.shape.row,
285
- (arguments.problem_size.column + self.operation.shape.column - 1)
286
- // self.operation.shape.column,
287
- 1,
288
- ]
289
- return LaunchConfiguration(
290
- grid_shape,
291
- block_shape,
292
- self.shared_memory_capacity,
293
- )
294
-
295
- def initialize(self):
296
- (err,) = cuda.cuFuncSetAttribute(
297
- self.kernel,
298
- attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
299
- value=self.shared_memory_capacity,
300
- )
301
- if err != cuda.CUresult.CUDA_SUCCESS:
302
- raise RuntimeError(f"CUDA Error: {err}")
303
-
304
-
305
- class ReductionOperation:
306
- """
307
- CUTLASS reduction Operation
308
- """
309
-
310
- def __init__(
311
- self,
312
- shape: MatrixCoord,
313
- C: TensorDescription,
314
- element_accumulator,
315
- element_workspace=None,
316
- element_compute=None,
317
- epilogue_functor=None,
318
- count: int = 1,
319
- partitions_per_stage: int = 4,
320
- ) -> None:
321
- self.shape = shape
322
- self.epilogue_functor = epilogue_functor
323
- self.element_accumulator = element_accumulator
324
-
325
- if element_workspace is None:
326
- self.element_workspace = element_accumulator
327
- else:
328
- self.element_workspace = element_workspace
329
-
330
- if element_compute is None:
331
- self.element_compute = element_accumulator
332
- else:
333
- self.element_compute = element_compute
334
-
335
- self.element_output = C.element
336
- self.C: TensorDescription = C
337
-
338
- # Reduce op processing size
339
- self.count: int = count
340
-
341
- # Number of partitions to reduce per stage
342
- self.partitions_per_stage: int = partitions_per_stage
343
-
344
- self.rt_module: ReductionRT = ReductionRT(self)
345
- self.argument_type = self.rt_module.argument_type
346
- self.epilogue_type = self.rt_module.epilogue_type
347
-
348
- def extended_name(self):
349
- extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
350
-
351
- return SubstituteTemplate(
352
- extend_name,
353
- {
354
- "element_workspace": DataTypeNames[self.element_workspace],
355
- "element_accumulator": DataTypeNames[self.element_accumulator],
356
- "element_compute": DataTypeNames[self.element_compute],
357
- "element_output": DataTypeNames[self.element_output],
358
- },
359
- )
360
-
361
- def configuration_name(self):
362
- """The full procedural name indicates architecture, extended name, tile size"""
363
-
364
- configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
365
-
366
- threadblock = "%dx%d" % (
367
- self.shape.row,
368
- self.shape.column,
369
- )
370
-
371
- return SubstituteTemplate(
372
- configuration_name,
373
- {
374
- "extended_name": self.extended_name(),
375
- "threadblock": threadblock,
376
- },
377
- )
378
-
379
- def procedural_name(self):
380
- """The full procedural name indicates architeture, extended name, tile size"""
381
- return self.configuration_name()
382
-
383
- def run(self, arguments: ReductionArguments) -> cuda.CUresult:
384
- """
385
- Configure and launch the cuda kernel with input arguments
386
- """
387
- launch_config = self.rt_module.plan(arguments)
388
-
389
- host_workspace = arguments.host_workspace
390
- device_workspace = None
391
-
392
- err = self.rt_module.run(
393
- host_workspace,
394
- device_workspace,
395
- launch_config,
396
- arguments.stream
397
- )
398
-
399
- if err != cuda.CUresult.CUDA_SUCCESS:
400
- raise RuntimeError(f"CUDA Error {str(err)}")
401
-
402
- return err
403
-
404
-
405
- class EmitReductionInstance:
406
- def __init__(self, operation_suffix="") -> None:
407
- self.operation_suffix = operation_suffix
408
- self.includes = [
409
- "cutlass/cutlass.h",
410
- "cutlass/numeric_types.h",
411
- "cutlass/arch/arch.h",
412
- "cutlass/arch/mma.h",
413
- "cutlass/layout/matrix.h",
414
- "cutlass/gemm/device/gemm.h",
415
- "cutlass/gemm/device/gemm_universal_adapter.h",
416
- "cutlass/gemm/kernel/default_gemm_universal.h",
417
- "cutlass/reduction/kernel/reduce_split_k.h",
418
- "cutlass/reduction/thread/reduction_operators.h",
419
- ]
420
- self.template = """
421
- // Reduction kernel instance
422
- using ${operation_name}_base =
423
- typename cutlass::reduction::kernel::ReduceSplitK<
424
- cutlass::MatrixShape<${shape_row}, ${shape_column}>,
425
- ${epilogue_functor},
426
- cutlass::reduction::thread::ReduceAdd<
427
- ${element_accumulator},
428
- ${element_output},
429
- ${count}>,
430
- ${partition_per_stage}>;
431
-
432
- struct ${operation_name}${operation_suffix}:
433
- public ${operation_name}_base { };
434
- """
435
-
436
- def emit(self, operation: ReductionOperation):
437
- vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
438
- epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
439
-
440
- values = {
441
- "operation_name": operation.configuration_name(),
442
- "operation_suffix": self.operation_suffix,
443
- "shape_row": str(operation.shape.row),
444
- "shape_column": str(operation.shape.column),
445
- "epilogue_functor": operation.epilogue_functor.emit(),
446
- "element_output": DataTypeTag[operation.element_output],
447
- "epilogue_vector_length": str(epilogue_vector_length),
448
- "element_accumulator": DataTypeTag[operation.element_accumulator],
449
- "element_compute": DataTypeTag[operation.element_compute],
450
- "element_workspace": DataTypeTag[operation.element_workspace],
451
- "count": str(operation.count),
452
- "partition_per_stage": str(operation.partitions_per_stage),
453
- }
454
-
455
- return SubstituteTemplate(self.template, values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py DELETED
@@ -1,35 +0,0 @@
1
- ################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- ################################################################################
32
-
33
- GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"
34
-
35
- Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py DELETED
@@ -1,33 +0,0 @@
1
- ################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- ################################################################################
32
-
33
- from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py DELETED
@@ -1,126 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utility functions for interacting with the device
35
- """
36
- from __future__ import annotations
37
-
38
- from cutlass_cppgen.utils.lazy_import import lazy_import
39
- cuda = lazy_import("cuda.cuda")
40
- cudart = lazy_import("cuda.cudart")
41
-
42
- import cutlass_cppgen
43
- from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
44
-
45
-
46
- def check_cuda_errors(result: list):
47
- """
48
- Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise,
49
- returns the result contained in the remaining fields of `result`.
50
-
51
- :param result: the results of the `cudart` method, consisting of an error code and any method results
52
- :type result: list
53
-
54
- :return: non-error-code results from the `results` parameter
55
- """
56
- # `result` is of the format : (cudaError_t, result...)
57
- err = result[0]
58
- if err.value:
59
- raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))
60
-
61
- if len(result) == 1:
62
- return None
63
- elif len(result) == 2:
64
- return result[1]
65
- else:
66
- return result[1:]
67
-
68
-
69
- def device_cc(device: int = -1) -> int:
70
- """
71
- Returns the compute capability of the device with ID `device`.
72
-
73
- :param device: ID of the device to query
74
- :type device: int
75
-
76
- :return: compute capability of the queried device (e.g., 80 for SM80)
77
- :rtype: int
78
- """
79
- if device == -1:
80
- device = cutlass_cppgen.device_id()
81
-
82
- deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
83
- major = str(deviceProp.major)
84
- minor = str(deviceProp.minor)
85
- return int(major + minor)
86
-
87
-
88
- def device_sm_count(device: int = -1):
89
- if device == -1:
90
- device = cutlass_cppgen.device_id()
91
- err, device_sm_count = cuda.cuDeviceGetAttribute(
92
- cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
93
- )
94
- if err != cuda.CUresult.CUDA_SUCCESS:
95
- raise Exception(
96
- "Failed to retireve SM count. "
97
- f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
98
- )
99
-
100
- return device_sm_count
101
-
102
-
103
- def to_device_ptr(tensor) -> cuda.CUdeviceptr:
104
- """
105
- Converts a tensor to a CUdeviceptr
106
-
107
- :param tensor: tensor to convert
108
- :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int
109
-
110
- :return: device pointer
111
- :rtype: cuda.CUdeviceptr
112
- """
113
- if is_numpy_tensor(tensor):
114
- ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
115
- elif is_torch_tensor(tensor):
116
- ptr = cuda.CUdeviceptr(tensor.data_ptr())
117
- elif is_cupy_tensor(tensor):
118
- ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
119
- elif isinstance(tensor, cuda.CUdeviceptr):
120
- ptr = tensor
121
- elif isinstance(tensor, int):
122
- ptr = cuda.CUdeviceptr(tensor)
123
- else:
124
- raise NotImplementedError(tensor)
125
-
126
- return ptr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py DELETED
@@ -1,33 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from cutlass_cppgen.emit.pytorch import pytorch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py DELETED
@@ -1,267 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Common utilities for emitting CUTLASS kernels
35
- """
36
-
37
- import cutlass_cppgen
38
-
39
- # Strings used for printing information about the generation of emitted scripts
40
- _AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
41
-
42
-
43
- _CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
44
- """
45
-
46
-
47
- _PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
48
- """
49
-
50
- _CUTLASS_KERNEL_ARGS_2x = """
51
- typename DeviceKernel::Arguments arguments {
52
- cutlass::gemm::GemmUniversalMode::kGemm,
53
- {M, N, K}, // problem size
54
- 1,
55
- {alpha, beta},
56
- A, B, C, D,
57
- 0, 0, 0, 0, // batch strides
58
- DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
59
- DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
60
- DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
61
- DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
62
- };
63
- """
64
-
65
- _CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
66
- typename DeviceKernel::Arguments arguments {
67
- cutlass::gemm::GemmUniversalMode::kGemm,
68
- {M, N, K}, // problem size
69
- 1,
70
- {alpha, beta},
71
- A, B, C, D,
72
- 0, 0, 0, 0, // batch strides
73
- DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
74
- DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
75
- DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
76
- DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
77
- -1 // avail_sms
78
- };
79
- """
80
-
81
- _CUTLASS_KERNEL_RUN_GEMM_2x = """
82
- using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
83
-
84
- cutlass::Status ${name}_kernel_run(int M, int N, int K,
85
- const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
86
- ElementCompute alpha, ElementCompute beta) {
87
- ${args}
88
- size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
89
- cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
90
-
91
- DeviceKernel gemm_op;
92
- cutlass::Status status = gemm_op.initialize(arguments,
93
- workspace.get(),
94
- nullptr); // CUDA stream
95
-
96
- if (status != cutlass::Status::kSuccess) {
97
- return status;
98
- }
99
-
100
- status = gemm_op();
101
- return status;
102
- }
103
- """
104
-
105
- _CUTLASS_KERNEL_RUN_GEMM_3x = """
106
- using StrideA = typename DeviceKernel::GemmKernel::StrideA;
107
- using StrideB = typename DeviceKernel::GemmKernel::StrideB;
108
- using StrideC = typename DeviceKernel::GemmKernel::StrideC;
109
- using StrideD = typename DeviceKernel::GemmKernel::StrideD;
110
-
111
- using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
112
-
113
- cutlass::Status ${name}_kernel_run(
114
- int M, int N, int K, int L,
115
- const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
116
- ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
117
-
118
- typename DeviceKernel::Arguments arguments{
119
- cutlass::gemm::GemmUniversalMode::kGemm,
120
- {M, N, K, L}, // problem size
121
- {
122
- A, // ptrA
123
- cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
124
- B, // ptrB
125
- cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
126
- },
127
- {
128
- {alpha, beta},
129
- C, // ptrC
130
- cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
131
- D, // ptrD
132
- cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
133
- },
134
- hw_info
135
- };
136
-
137
- size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
138
- cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
139
-
140
- DeviceKernel gemm_op;
141
- cutlass::Status status = gemm_op.run(arguments,
142
- workspace.get(),
143
- nullptr); // CUDA stream
144
-
145
- return status;
146
- }
147
- """
148
-
149
-
150
- _CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
151
- using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
152
-
153
- int threadblock_count = DeviceKernel::sufficient();
154
-
155
- cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
156
- DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
157
- int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
158
- ElementCompute alpha, ElementCompute beta) {
159
-
160
- typename DeviceKernel::Arguments arguments {
161
- problem_sizes,
162
- problem_count,
163
- threadblock_count,
164
- {alpha, beta},
165
- A, B, C, D,
166
- lda, ldb, ldc, ldd
167
- };
168
-
169
- size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
170
- cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
171
-
172
- DeviceKernel gemm_op;
173
- cutlass::Status status = gemm_op.initialize(arguments,
174
- workspace.get(),
175
- nullptr); // CUDA stream
176
-
177
- if (status != cutlass::Status::kSuccess) {
178
- return status;
179
- }
180
-
181
- status = gemm_op();
182
- return status;
183
- }
184
- """
185
-
186
-
187
- _CUTLASS_KERNEL_RUN_CONV2D_2x = """
188
-
189
- using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
190
- namespace {
191
- using TensorRefA = typename UnderlyingKernel::TensorRefA;
192
- using TensorRefB = typename UnderlyingKernel::TensorRefB;
193
- using TensorRefC = typename UnderlyingKernel::TensorRefC;
194
- using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
195
- }
196
-
197
- template<typename TensorRef, typename Element>
198
- TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
199
- cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
200
- TensorRef tensor_ref(ptr, layout);
201
- return tensor_ref;
202
- }
203
-
204
- cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
205
- UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
206
- UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
207
- ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
208
- cudaStream_t stream, int device_id=0) {
209
- // create the tensor references
210
- cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
211
- cutlass::conv::Operator::k${conv_kind_name}, *problem_size
212
- );
213
- cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
214
- cutlass::conv::Operator::k${conv_kind_name}, *problem_size
215
- );
216
- cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
217
- cutlass::conv::Operator::k${conv_kind_name}, *problem_size
218
- );
219
-
220
- TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
221
- TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
222
- TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
223
- TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
224
-
225
- cutlass::conv::SplitKMode mode;
226
- if (split_k_mode == "serial") {
227
- mode = cutlass::conv::SplitKMode::kSerial;
228
- } else if (split_k_mode == "parallel") {
229
- mode = cutlass::conv::SplitKMode::kParallel;
230
- } else {
231
- throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
232
- }
233
-
234
- typename DeviceKernel::Arguments arguments{
235
- *problem_size,
236
- tensor_ref_A,
237
- tensor_ref_B,
238
- tensor_ref_C,
239
- tensor_ref_D,
240
- {alpha, beta},
241
- mode
242
- };
243
-
244
- DeviceKernel implicit_gemm_op;
245
-
246
- size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
247
-
248
- void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
249
-
250
- cutlass::Status status = implicit_gemm_op.can_implement(arguments);
251
- if (status != cutlass::Status::kSuccess) {
252
- return status;
253
- }
254
-
255
- status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
256
- if (status != cutlass::Status::kSuccess) {
257
- return status;
258
- }
259
-
260
- //
261
- // Launch initialized CUTLASS kernel
262
- //
263
- status = implicit_gemm_op(stream);
264
-
265
- return status;
266
- }
267
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py DELETED
@@ -1,936 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel.
35
- If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.
36
-
37
- Example usage with JIT compilation:
38
-
39
- .. highlight:: python
40
- .. code-block:: python
41
-
42
- plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
43
- op = plan.construct()
44
- mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
45
-
46
- # Generate inputs for the GEMM
47
- A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
48
-
49
- # Run the module
50
- D = mod.run(A, B, C)
51
-
52
-
53
- Example usage without JIT compilation:
54
-
55
- .. highlight:: python
56
- .. code-block:: python
57
-
58
- plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
59
- op = plan.construct()
60
- cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
61
-
62
- After this call, the directory ``output`` contains ``setup.py``,
63
- ``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
64
- within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.
65
-
66
- The module can later be used in Python via:
67
-
68
- .. highlight:: python
69
- .. code-block:: python
70
-
71
- import torch
72
- import cutlass_gemm
73
-
74
- # Generate inputs for the GEMM
75
- A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
76
-
77
- # Run the module
78
- D = cutlass_gemm.run(A, B, C)
79
- """
80
-
81
- import logging
82
- import os
83
-
84
- from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
85
-
86
- from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
87
- from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
88
- from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
89
- from cutlass_cppgen.backend.library import ApiVersion
90
- from cutlass_cppgen.emit import common
91
- from cutlass_cppgen.utils.datatypes import is_torch_available
92
-
93
- if is_torch_available():
94
- import torch
95
-
96
-
97
- _PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
98
- #include <cuda_runtime.h>
99
- #include <torch/extension.h>
100
- #include <ATen/ATen.h>
101
- #include <ATen/cuda/CUDAContext.h>
102
- #include "cutlass/cutlass.h"
103
- #include "cutlass/util/device_memory.h"
104
-
105
- // helper function allocating the memory
106
- void* device_memory_allocation(size_t size, int device_id=0) {
107
- if (size > 0) {
108
- torch::Device device(torch::kCUDA, device_id);
109
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
110
- torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
111
- at::Tensor device_tensor = torch::empty({(long)size,}, options);
112
- return reinterpret_cast<void*>(device_tensor.data_ptr());
113
- } else {
114
- return nullptr;
115
- }
116
- }
117
-
118
- ${includes}
119
- ${declaration}
120
- ${impl}
121
- """
122
-
123
- _PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
124
- #include <torch/extension.h>
125
- #include <ATen/ATen.h>
126
- #include <pybind11/stl.h>
127
-
128
- // CUDA forward declarations
129
- at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);
130
-
131
- // C++ interface
132
- at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
133
- return ${name}_kernel(A, B, C, alpha, beta);
134
- }
135
-
136
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
137
- m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
138
- }
139
- """
140
-
141
- _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
142
- #include <torch/extension.h>
143
- #include <ATen/ATen.h>
144
- #include <pybind11/stl.h>
145
-
146
- // CUDA forward declarations
147
- std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);
148
-
149
- // C++ interface
150
- std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
151
- return ${name}_kernel(A, B, C, alpha, beta);
152
- }
153
-
154
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
155
- m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
156
- py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
157
- }
158
- """
159
-
160
- _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
161
- #include <torch/extension.h>
162
- #include <ATen/ATen.h>
163
- #include <pybind11/stl.h>
164
-
165
- // CUDA forward declarations
166
- at::Tensor ${name}_kernel(
167
- const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
168
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
169
- float alpha=1.f, float beta=0.f,
170
- std::string split_k_mode="serial", int split_k_slices=1);
171
-
172
- // C++ interface
173
- at::Tensor ${name}(
174
- const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
175
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
176
- float alpha=1.f, float beta=0.f,
177
- std::string split_k_mode="serial", int split_k_slices=1) {
178
- return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
179
- }
180
-
181
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
182
- m.def("run",
183
- py::overload_cast<
184
- const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
185
- std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
186
- &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
187
- py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
188
- py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
189
- py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
190
- }
191
- """
192
-
193
- _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
194
- #include <torch/extension.h>
195
- #include <ATen/ATen.h>
196
- #include <pybind11/stl.h>
197
-
198
- // CUDA forward declarations
199
- at::Tensor ${name}_kernel(
200
- std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
201
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
202
- float alpha=1.f, float beta=0.f,
203
- std::string split_k_mode="serial", int split_k_slices=1);
204
-
205
- // C++ interface
206
- at::Tensor ${name}(
207
- std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
208
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
209
- float alpha=1.f, float beta=0.f,
210
- std::string split_k_mode="serial", int split_k_slices=1) {
211
- return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
212
- }
213
-
214
- PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
215
- m.def("run",
216
- py::overload_cast<
217
- std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
218
- std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
219
- &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
220
- py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
221
- py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
222
- py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
223
- }
224
- """
225
-
226
- _PYTORCH_GEMM_INCLUDES = {
227
- ApiVersion.v2x: """
228
- #include "cutlass/gemm/device/gemm_universal.h"
229
- """,
230
- ApiVersion.v3x: """
231
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
232
- #include "cutlass/gemm/collective/collective_builder.hpp"
233
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
234
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
235
- #include "cutlass/epilogue/collective/collective_builder.hpp"
236
- #include "cutlass/util/packed_stride.hpp"
237
- """,
238
- }
239
-
240
- _PYTORCH_GROUPED_GEMM_INCLUDES = """
241
- #include "cutlass/gemm/kernel/default_gemm_grouped.h"
242
- #include "cutlass/gemm/device/gemm_grouped.h"
243
- """
244
-
245
- _PYTORCH_CONV2D_INCLUDES = """
246
- #include "cutlass/conv/kernel/default_conv2d_fprop.h"
247
- #include "cutlass/conv/kernel/default_conv2d_dgrad.h"
248
- #include "cutlass/conv/kernel/default_conv2d_wgrad.h"
249
- #include "cutlass/conv/device/implicit_gemm_convolution.h"
250
- """
251
-
252
- _CUTLASS_TYPE_TO_TORCH_TYPE = {
253
- DataType.f16: "torch::kF16",
254
- DataType.f32: "torch::kF32",
255
- DataType.f64: "torch::kF64",
256
- DataType.s8: "torch::kI8",
257
- DataType.s32: "torch::kI32",
258
- DataType.bf16: "torch::kBFloat16",
259
- }
260
-
261
- _PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
262
- common._CUTLASS_KERNEL_RUN_GEMM_2x
263
- + """
264
- at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
265
- int M = A.size(0);
266
- int N = B.size(1);
267
- int K = A.size(1);
268
-
269
- typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
270
- nullptr :
271
- reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
272
- at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
273
-
274
- cutlass::Status status = ${name}_kernel_run(M, N, K,
275
- reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
276
- reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
277
- ptrC,
278
- reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
279
- ElementCompute(alpha), ElementCompute(beta));
280
-
281
- TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
282
- return D;
283
- }
284
- """
285
- )
286
-
287
- _PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
288
- common._CUTLASS_KERNEL_RUN_GEMM_3x
289
- + """
290
- bool hw_info_queried = false;
291
- cutlass::KernelHardwareInfo hw_info;
292
-
293
- at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
294
- int M = A.size(0);
295
- int N = B.size(1);
296
- int K = A.size(1);
297
- int L = 1;
298
-
299
- // Query hardware info if we haven't already
300
- if (!hw_info_queried) {
301
- hw_info.device_id = 0;
302
- hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
303
- }
304
-
305
- typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
306
- nullptr :
307
- reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
308
- at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
309
-
310
- cutlass::Status status = ${name}_kernel_run(M, N, K, L,
311
- reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
312
- reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
313
- ptrC,
314
- reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
315
- ElementCompute(alpha), ElementCompute(beta),
316
- hw_info);
317
-
318
- TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
319
- return D;
320
- }
321
- """
322
- )
323
-
324
-
325
- _PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
326
- common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
327
- + """
328
- std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
329
- size_t num = A.size();
330
-
331
- // To avoid performing many small cudaMallocs and host-to-device copies,
332
- // we serialize the grouped GEMM arguments on the host, allocate one
333
- // large chunk of device memory, and perform a single cudaMemcpy to
334
- // copy the host data to the device. Allocation overheads could be
335
- // avoided by using a memory pool.
336
-
337
- // Calculate the total size of the data to be copied from host to device
338
- size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
339
- sizeof(DeviceKernel::ElementA*) +
340
- sizeof(DeviceKernel::ElementB*) +
341
- sizeof(DeviceKernel::ElementC*) +
342
- sizeof(DeviceKernel::ElementC*) +
343
- sizeof(int64_t) +
344
- sizeof(int64_t) +
345
- sizeof(int64_t);
346
- total_size *= num;
347
-
348
- // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
349
- // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
350
- // To ensure that we don't end up having misaligned loads in the kernel,
351
- // we pad to the nearest multiple of 8.
352
- //
353
- // Note that, even on a 32-bit system (for which sizeof(X*) will not equal
354
- // sizeof(int64_t)), only padding between the list of GemmCoords and the
355
- // list of ptr_As is sufficient because the set of four equal-length lists of pointers
356
- // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
357
- // start on a multiple of 8.
358
- int64_t padding = 8 - (total_size % 8);
359
- total_size += padding;
360
-
361
- uint8_t* host_data = new uint8_t[total_size];
362
- cutlass::DeviceAllocation<uint8_t> device_data(total_size);
363
-
364
- uint8_t* start = host_data;
365
- cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);
366
-
367
- // Apply the padding after the list of GemmCoords
368
- start += num * sizeof(cutlass::gemm::GemmCoord) + padding;
369
-
370
- int64_t ptr_A_offset = start - host_data;
371
- DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
372
- start += num * sizeof(DeviceKernel::ElementA*);
373
-
374
- int64_t ptr_B_offset = start - host_data;
375
- DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
376
- start += num * sizeof(DeviceKernel::ElementB*);
377
-
378
- int64_t ptr_C_offset = start - host_data;
379
- DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
380
- start += num * sizeof(DeviceKernel::ElementC*);
381
-
382
- int64_t ptr_D_offset = start - host_data;
383
- DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
384
- start += num * sizeof(DeviceKernel::ElementC*);
385
-
386
- int64_t lda_offset = start - host_data;
387
- int64_t* lda_host = reinterpret_cast<int64_t*>(start);
388
- start += num * sizeof(int64_t);
389
-
390
- int64_t ldb_offset = start - host_data;
391
- int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
392
- start += num * sizeof(int64_t);
393
-
394
- int64_t ldc_offset = start - host_data;
395
- int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
396
- start += num * sizeof(int64_t);
397
-
398
- std::vector<at::Tensor> D(num);
399
-
400
- bool need_C = (C != at::nullopt) && (beta != 0.f);
401
- for (size_t i = 0; i < num; ++i) {
402
- int M = A[i].size(0);
403
- int N = B[i].size(1);
404
- int K = A[i].size(1);
405
- *(problem_sizes_host + i) = {M, N, K};
406
- *(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
407
- *(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());
408
-
409
- if (need_C) {
410
- *(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
411
- }
412
- else {
413
- *(ptr_C_host + i) = nullptr;
414
- }
415
-
416
- D[i] = B[i].new_empty({M, N}, ${torch_type_C});
417
- *(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());
418
-
419
- *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
420
- *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
421
- *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
422
- }
423
-
424
- device_data.copy_from_host(host_data);
425
-
426
- cutlass::Status status = ${name}_kernel_run(
427
- num,
428
- reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
429
- reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
430
- reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
431
- reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
432
- reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
433
- reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
434
- reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
435
- reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
436
- reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
437
- ElementCompute(alpha), ElementCompute(beta));
438
-
439
- delete[] host_data;
440
-
441
- TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
442
- return D;
443
- }
444
- """
445
- )
446
-
447
- _PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
448
- cudaStream_t stream = at::cuda::getCurrentCUDAStream();
449
-
450
- cutlass::Status status = ${name}_kernel_run(
451
- &problem_size,
452
- reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
453
- reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
454
- ptrC,
455
- reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
456
- alpha, beta,
457
- split_k_mode, stream, B.device().index());
458
-
459
- TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
460
- return D;
461
- }
462
- """
463
-
464
- _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
465
- common._CUTLASS_KERNEL_RUN_CONV2D_2x
466
- + """
467
- at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
468
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
469
- float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
470
- int N, H, W, C_, K, R, S, P, Q;
471
- N = A.size(0);
472
- C_ = A.size(1);
473
- H = A.size(2);
474
- W = A.size(3);
475
-
476
- K = B.size(0);
477
- R = B.size(2);
478
- S = B.size(3);
479
-
480
- cutlass::conv::Conv2dProblemSize problem_size(
481
- cutlass::Tensor4DCoord(N, H, W, C_),
482
- cutlass::Tensor4DCoord(K, R, S, C_),
483
- cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
484
- cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
485
- cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
486
- cutlass::conv::Mode::kCrossCorrelation,
487
- split_k_slices
488
- );
489
-
490
- P = problem_size.P;
491
- Q = problem_size.Q;
492
-
493
- typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
494
- nullptr :
495
- reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
496
-
497
- torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
498
- at::Tensor D = torch::zeros({N, K, P, Q}, options);
499
- """ + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
500
- )
501
-
502
-
503
- _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
504
- common._CUTLASS_KERNEL_RUN_CONV2D_2x
505
- + """
506
- at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
507
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
508
- std::string split_k_mode="serial", int split_k_slices=1) {
509
- int N, H, W, C_, K, R, S;
510
- N = std::get<0>(input_size);
511
- C_ = std::get<1>(input_size);
512
- H = std::get<2>(input_size);
513
- W = std::get<3>(input_size);
514
-
515
- K = B.size(0);
516
- R = B.size(2);
517
- S = B.size(3);
518
-
519
- cutlass::conv::Conv2dProblemSize problem_size(
520
- cutlass::Tensor4DCoord(N, H, W, C_),
521
- cutlass::Tensor4DCoord(K, R, S, C_),
522
- cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
523
- cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
524
- cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
525
- cutlass::conv::Mode::kCrossCorrelation,
526
- split_k_slices
527
- );
528
-
529
- typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
530
- nullptr :
531
- reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
532
-
533
- torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
534
- at::Tensor D = torch::empty({N, C_, H, W}, options);
535
- """ + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
536
- )
537
-
538
-
539
- _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
540
- common._CUTLASS_KERNEL_RUN_CONV2D_2x
541
- + """
542
- at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
543
- std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
544
- std::string split_k_mode="serial", int split_k_slices=1) {
545
- int N, H, W, C_, K, R, S;
546
- K = std::get<0>(weight_size);
547
- C_ = std::get<1>(weight_size);
548
- R = std::get<2>(weight_size);
549
- S = std::get<3>(weight_size);
550
-
551
- N = B.size(0);
552
- H = B.size(2);
553
- W = B.size(3);
554
-
555
- cutlass::conv::Conv2dProblemSize problem_size(
556
- cutlass::Tensor4DCoord(N, H, W, C_),
557
- cutlass::Tensor4DCoord(K, R, S, C_),
558
- cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
559
- cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
560
- cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
561
- cutlass::conv::Mode::kCrossCorrelation,
562
- split_k_slices
563
- );
564
-
565
- typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
566
- nullptr :
567
- reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
568
-
569
- torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
570
- at::Tensor D = torch::empty({K, C_, R, S}, options);
571
- """ + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
572
- )
573
-
574
-
575
- _PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
576
- from setuptools import setup
577
- from torch.utils.cpp_extension import BuildExtension, CUDAExtension
578
-
579
- setup(
580
- name='${name}',
581
- ext_modules=[
582
- CUDAExtension('${name}', [
583
- '${name}.cpp',
584
- '${name}_kernel.cu',
585
- ],
586
- include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
587
- extra_compile_args={
588
- 'cxx': ['-std=c++17'],
589
- 'nvcc': ['-std=c++17', ${extra_compile_args}],
590
- },
591
- libraries=['cuda']
592
- ),
593
- ],
594
- cmdclass={
595
- 'build_ext': BuildExtension
596
- })
597
-
598
- """
599
-
600
-
601
- def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
602
- """
603
- Generates a setup.py file for the extension
604
-
605
- :param name: name of the module to generate
606
- :type name: str
607
- :param sourcedir: directory to which generated source files should be written
608
- :type sourcedir: str
609
- :param extra_compile_args: additional arguments to pass to setup.py
610
- :type extra_args: str
611
- """
612
- setup_py_file = os.path.join(sourcedir, "setup.py")
613
- setup_source = SubstituteTemplate(
614
- _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
615
- )
616
- with open(setup_py_file, "w") as outfile:
617
- outfile.write(setup_source)
618
-
619
-
620
- class _ArchListSetter:
621
- """
622
- Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
623
- environment variable when building a PyTorch CUDA module.
624
-
625
- ``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch
626
- CUDA module should be compiled.
627
-
628
- For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
629
- ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
630
- compilation of the module.
631
-
632
- This utility wraps the building of a PyTorch CUDA module with a setting of this environment
633
- variable according to the current compute capability being targetted.
634
-
635
- Example usage:
636
-
637
- .. highlight:: python
638
- .. code-block:: python
639
-
640
- # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
641
- with _ArchListSetter(80):
642
- # Perform JIT compilation and loading of the module
643
- mod = torch.utils.cpp_extension.load(...)
644
-
645
- :param cc: compute capability
646
- :type cc: int
647
- """
648
-
649
- _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"
650
-
651
- def __init__(self, cc: int):
652
- self.cc_str = ".".join(list(str(cc)))
653
-
654
- def __enter__(self):
655
- """
656
- Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc``
657
- """
658
- self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
659
- os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str
660
-
661
- return self
662
-
663
- def __exit__(self, exc_type, exc_val, traceback):
664
- """
665
- Restores the old value of TORCH_CUDA_ARCH_LIST
666
- """
667
- if self.old_arch_list is None:
668
- del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
669
- else:
670
- os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
671
-
672
-
673
- def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
674
- """
675
- JIT compiles and loads a PyTorch CUDA extension.
676
-
677
- :param name: name of the module to generate
678
- :type name: str
679
- :param cc: compute capability of the device the module should target
680
- :type cc: int
681
- :param cpp_file: path to file containing extension's C++ interface
682
- :type cpp_file: str
683
- :param cuda_file: path to file containing extension's CUDA interface
684
- :type cuda_file: str
685
-
686
- :return: loaded PyTorch module
687
- """
688
-
689
- from torch.utils.cpp_extension import load
690
-
691
- extra_cuda_cflags = ["-std=c++17"]
692
- if cc in [90, 100, 101, 103]:
693
- # PyTorch does not currently add the sm_90a target when compute capability
694
- # 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.
695
- extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")
696
-
697
- with _ArchListSetter(cc):
698
- jitmodule = load(
699
- name,
700
- [cpp_file, cuda_file],
701
- extra_cuda_cflags=extra_cuda_cflags,
702
- extra_include_paths=[
703
- os.path.join(CUTLASS_PATH, "include"),
704
- os.path.join(CUTLASS_PATH, "tools/util/include"),
705
- ],
706
- extra_ldflags=["-lcuda"],
707
- verbose=(logger.level == logging.DEBUG)
708
- )
709
- return jitmodule
710
-
711
-
712
- def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
713
- """
714
- Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
715
- specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
716
- compiled, loaded, and returned.
717
-
718
- :param op: operation to emit in the module
719
- :param name: name of the module to generate
720
- :type name: str
721
- :param cc: compute capability of the device the module should target
722
- :type cc: int
723
- :param jit: whether the module should be just-in-time compiled
724
- :type jit: bool
725
- :param sourcedir: directory to which generated source files should be written
726
- :type sourcedir: str
727
-
728
- :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
729
- """
730
- if sourcedir != "" and not os.path.isdir(sourcedir):
731
- os.makedirs(sourcedir)
732
-
733
- cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
734
- extra_kw = {}
735
- if op.api == ApiVersion.v3x:
736
- impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
737
- else:
738
- impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
739
- if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
740
- extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
741
- else:
742
- extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
743
- impl_template = (
744
- _PYTORCH_GEMM_IMPL_TEMPLATE_3x
745
- if op.api == ApiVersion.v3x
746
- else _PYTORCH_GEMM_IMPL_TEMPLATE_2x
747
- )
748
- cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
749
- cuda_source = SubstituteTemplate(
750
- _PYTORCH_CUDA_TEMPLATE,
751
- {
752
- "includes": _PYTORCH_GEMM_INCLUDES[op.api],
753
- "declaration": op.rt_module.emit(),
754
- "procedural_name": op.procedural_name(),
755
- "impl": cuda_impl,
756
- "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
757
- },
758
- )
759
- with open(cuda_file, "w") as outfile:
760
- outfile.write(cuda_source)
761
-
762
- cpp_file = os.path.join(sourcedir, name + ".cpp")
763
- cpp_source = SubstituteTemplate(
764
- _PYTORCH_GEMM_CPP_TEMPLATE,
765
- {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
766
- )
767
- with open(cpp_file, "w") as outfile:
768
- outfile.write(cpp_source)
769
-
770
- extra_compile_args = ""
771
- if cc in [90, 100, 101, 103]:
772
- extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
773
- _generate_setup(name, sourcedir, extra_compile_args)
774
-
775
- if jit:
776
- return _jit(name, cc, cpp_file, cuda_file)
777
-
778
- return None
779
-
780
-
781
- def _pytorch_grouped_gemm(
782
- op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
783
- ):
784
- """
785
- Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM
786
- specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
787
- compiled, loaded, and returned.
788
-
789
- :param op: operation to emit in the module
790
- :param name: name of the module to generate
791
- :type name: str
792
- :param cc: compute capability of the device the module should target
793
- :type cc: int
794
- :param jit: whether the module should be just-in-time compiled
795
- :type jit: bool
796
- :param sourcedir: directory to which generated source files should be written
797
- :type sourcedir: str
798
-
799
- :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
800
- """
801
- if op.api != ApiVersion.v2x:
802
- raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")
803
-
804
- if sourcedir != "" and not os.path.isdir(sourcedir):
805
- os.makedirs(sourcedir)
806
-
807
- cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
808
- cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
809
- cuda_source = SubstituteTemplate(
810
- _PYTORCH_CUDA_TEMPLATE,
811
- {
812
- "includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
813
- "declaration": op.rt_module.emit(),
814
- "procedural_name": op.procedural_name(),
815
- "impl": cuda_impl,
816
- "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
817
- },
818
- )
819
- with open(cuda_file, "w") as outfile:
820
- outfile.write(cuda_source)
821
-
822
- cpp_file = os.path.join(sourcedir, name + ".cpp")
823
- cpp_source = SubstituteTemplate(
824
- _PYTORCH_GROUPED_GEMM_CPP_TEMPLATE,
825
- {"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"},
826
- )
827
- with open(cpp_file, "w") as outfile:
828
- outfile.write(cpp_source)
829
-
830
- _generate_setup(name, sourcedir)
831
-
832
- if jit:
833
- return _jit(name, cc, cpp_file, cuda_file)
834
-
835
- return None
836
-
837
-
838
- def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
839
- """
840
- Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
841
- specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
842
- compiled, loaded, and returned.
843
-
844
- :param op: operation to emit in the module
845
- :param name: name of the module to generate
846
- :type name: str
847
- :param cc: compute capability of the device the module should target
848
- :type cc: int
849
- :param jit: whether the module should be just-in-time compiled
850
- :type jit: bool
851
- :param sourcedir: directory to which generated source files should be written
852
- :type sourcedir: str
853
-
854
- Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
855
- weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
856
- for H/W/R/S given the same P/Q.
857
-
858
- :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
859
- """
860
- if sourcedir != "" and not os.path.isdir(sourcedir):
861
- os.makedirs(sourcedir)
862
- cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
863
- extra_kw = {}
864
- if op.conv_kind == ConvKind.Fprop:
865
- impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
866
- cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
867
- elif op.conv_kind == ConvKind.Dgrad:
868
- impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
869
- cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
870
- elif op.conv_kind == ConvKind.Wgrad:
871
- impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
872
- cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
873
- extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
874
- extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
875
- cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
876
- cuda_source = SubstituteTemplate(
877
- _PYTORCH_CUDA_TEMPLATE,
878
- {
879
- "includes": _PYTORCH_CONV2D_INCLUDES,
880
- "declaration": op.rt_module.emit(),
881
- "procedural_name": op.procedural_name(),
882
- "impl": cuda_impl,
883
- "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
884
- },
885
- )
886
- with open(cuda_file, "w") as outfile:
887
- outfile.write(cuda_source)
888
-
889
- cpp_file = os.path.join(sourcedir, name + ".cpp")
890
- cpp_source = SubstituteTemplate(
891
- cpp_template,
892
- {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
893
- )
894
- with open(cpp_file, "w") as outfile:
895
- outfile.write(cpp_source)
896
-
897
- _generate_setup(name, sourcedir)
898
-
899
- if jit:
900
- return _jit(name, cc, cpp_file, cuda_file)
901
-
902
- return None
903
-
904
-
905
- def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
906
- """
907
- Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
908
- specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
909
- compiled, loaded, and returned.
910
-
911
- The result of this method is files within ``sourcedir`` that can be used for building
912
- a PyTorch module.
913
-
914
- :param op: operation to emit in the module
915
- :param name: name of the module to generate
916
- :type name: str
917
- :param cc: compute capability of the device the module should target
918
- :type cc: int
919
- :param jit: whether the module should be just-in-time compiled
920
- :type jit: bool
921
- :param sourcedir: directory to which generated source files should be written
922
- :type sourcedir: str
923
-
924
- :return: loaded PyTorch module (if ``jit=True``) or None
925
- """
926
- device_op = op.device_op()
927
- if isinstance(op, GemmOperationUniversal):
928
- return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
929
- elif isinstance(op, GemmOperationGrouped):
930
- return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
931
- elif isinstance(op, Conv2dOperation):
932
- return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
933
- else:
934
- raise Exception(
935
- f"Operation type {type(op)} is not currently supported for PyTorch emission."
936
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py DELETED
@@ -1,56 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from cutlass_cppgen.epilogue.epilogue import (
34
- get_activations,
35
- get_activation_epilogue,
36
- gelu,
37
- hardswish,
38
- identity,
39
- leaky_relu,
40
- relu,
41
- sigmoid,
42
- silu,
43
- tanh,
44
- trace
45
- )
46
-
47
- from cutlass_cppgen.epilogue.evt_ops import (
48
- max,
49
- multiply_add,
50
- sum,
51
- permute,
52
- reshape,
53
- maximum,
54
- minimum,
55
- exp
56
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py DELETED
@@ -1,176 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Registry of elementwise epilogues
35
-
36
- Elementwise epilogues can be added to many CUTLASS kernels in the CUTLAS Python interface via
37
- code like the following for GEMM:
38
-
39
- .. highlight:: python
40
- .. code-block:: python
41
-
42
- plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
43
- plan.activation = cutlass_cppgen.epilogue.relu
44
- """
45
-
46
- from cutlass_cppgen.backend import epilogue, device_cc
47
-
48
-
49
- gelu = epilogue.gelu
50
- hardswish = epilogue.hardswish
51
- identity = epilogue.identity
52
- leaky_relu = epilogue.leaky_relu
53
- relu = epilogue.relu
54
- sigmoid = epilogue.sigmoid
55
- silu = epilogue.silu
56
- tanh = epilogue.tanh
57
-
58
-
59
- _activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]
60
-
61
-
62
- def get_activations() -> list:
63
- """
64
- Returns a list of available activation functions
65
-
66
- :return: list of available activation functions
67
- :rtype: list
68
- """
69
- return _activations
70
-
71
-
72
- def get_activation_epilogue(
73
- activation,
74
- element_output,
75
- elements_per_access,
76
- element_accumulator,
77
- element_compute,
78
- ):
79
- """
80
- Return an epilogue corresponding to the activation function, data types, and alignment
81
- used in the kernel
82
-
83
- :param activation: elementwise activation function to use
84
- :param element_output: data type of the output
85
- :param elements_per_access: alignment of operand C of the kernel
86
- :type elements_per_access: int
87
- :param element_accumulator: data type of the accumulated output C
88
- :param element_compute: data type in which compute operations should be performed
89
-
90
- :return: epilogue functor
91
- """
92
- if activation not in _activations:
93
- raise Exception(
94
- f"Unsupported activation type {activation}. Available activations are: {_activations}"
95
- )
96
-
97
- if activation == identity:
98
- return epilogue.LinearCombination(
99
- element_output, elements_per_access, element_accumulator, element_compute
100
- )
101
- else:
102
- return epilogue.LinearCombinationGeneric(
103
- activation,
104
- element_output,
105
- elements_per_access,
106
- element_accumulator,
107
- element_compute,
108
- )
109
-
110
-
111
- """
112
- Frontend for EVT that generates epilogue functor through tracing the input function
113
- """
114
- from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
115
-
116
-
117
- def trace(fn, example_tensors, **kwargs):
118
- """
119
- Trace `fn(**example_tensors)` and generates epilogue visitor
120
-
121
- :param fn or str: Python callable or string of the epilogue function
122
- :param example_tensors: example inputs for fn
123
- :type example_tensors: dict
124
-
125
- .. hightlight:: python
126
- .. code-block:: python
127
- import cutlass_cppgen.backend.evt
128
-
129
- # Define epilogue function as Python callable
130
- def example_fn(accum, C, alpha, beta, gamma):
131
- D = ((accum + C) * alpha - gamma) / beta
132
- return D
133
-
134
- # Define the example tensors
135
- example_inputs = {
136
- "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
137
- "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
138
- "alpha": 1.5,
139
- "beta": 0.5,
140
- "gamma": 2.5,
141
- "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
142
- }
143
-
144
- # Generate the epilogue functor
145
- epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
146
- """
147
- if callable(fn):
148
- class EpilogueFunctor(PythonASTFrontend):
149
- def __init__(self, cc=None, **kwargs):
150
- if not cc:
151
- cc = device_cc()
152
- super().__init__(cc, **kwargs)
153
- pass
154
- setattr(EpilogueFunctor, "__call__", staticmethod(fn))
155
-
156
- epilogue_functor = EpilogueFunctor(**kwargs)
157
- epilogue_functor.trace(example_tensors)
158
- return epilogue_functor
159
- elif isinstance(fn, str):
160
- class EpilogueFunctor(PythonASTFrontend):
161
- def __init__(self, cc=None, **kwargs):
162
- self.source = textwrap.dedent(fn)
163
- if not cc:
164
- cc = device_cc()
165
- super().__init__(cc, **kwargs)
166
-
167
- def parse(self, example_inputs) -> None:
168
- self.example_inputs = example_inputs
169
- self.ast = ast.parse(self.source)
170
- self.visit(self.ast)
171
-
172
- epilogue_functor = EpilogueFunctor(**kwargs)
173
- epilogue_functor.trace(example_tensors)
174
- return epilogue_functor
175
- else:
176
- raise NotImplementedError("Expect a callable Python function")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py DELETED
@@ -1,98 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Collection of builtin functions used for host reference in EVT
35
- """
36
-
37
- import numpy as np
38
-
39
- from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
40
-
41
- if is_torch_available():
42
- import torch
43
-
44
-
45
- def multiply_add(x, y, z):
46
- return x * y + z
47
-
48
-
49
- def sum(x, dim):
50
- if is_numpy_tensor(x):
51
- return x.sum(axis=tuple(dim))
52
- elif is_torch_tensor(x):
53
- return torch.sum(x, dim)
54
-
55
-
56
- def max(x, dim):
57
- if is_numpy_tensor(x):
58
- return x.max(axis=tuple(dim))
59
- elif is_torch_tensor(x):
60
- return torch.amax(x, dim)
61
-
62
-
63
- def maximum(x, y):
64
- if is_numpy_tensor(x):
65
- return np.maximum(x, y)
66
- elif is_torch_tensor(x):
67
- return torch.maximum(x, torch.tensor(y))
68
-
69
-
70
- def minimum(x, y):
71
- if is_numpy_tensor(x):
72
- return np.minimum(x, y)
73
- elif is_torch_tensor(x):
74
- return torch.minimum(x, torch.tensor(y))
75
-
76
- def exp(x):
77
- if is_numpy_tensor(x):
78
- return np.exp(x)
79
- elif is_torch_tensor(x):
80
- return torch.exp(x)
81
-
82
-
83
- ##############################################################################
84
- # Layout manipulate nodes
85
- ##############################################################################
86
-
87
- def permute(x, indices: tuple):
88
- if is_numpy_tensor(x):
89
- return np.transpose(x, axes=indices)
90
- elif is_torch_tensor(x):
91
- return x.permute(*indices)
92
-
93
-
94
- def reshape(x, new_shape: tuple):
95
- if is_numpy_tensor(x):
96
- return np.reshape(x, newshape=new_shape)
97
- elif is_torch_tensor(x):
98
- return x.view(new_shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py DELETED
@@ -1,569 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Classes containing valid operations for a given compute capability and data types.
35
- """
36
-
37
- from itertools import combinations_with_replacement
38
- import logging
39
-
40
- import cutlass_library
41
- from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
42
-
43
- import cutlass_cppgen
44
- from cutlass_cppgen.utils.check import valid_stage_count
45
- from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
46
-
47
-
48
- _generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]
49
-
50
-
51
- class KernelsForDataType:
52
- """
53
- Container class for keeping track of kernels that correspond to a particular combination
54
- of data types for operands A, B, and accumulator
55
- """
56
-
57
- def __init__(self, datatype_comb: tuple, layout_comb: tuple):
58
- self.datatype_comb = datatype_comb
59
- self.layout_comb = layout_comb
60
- self.math_operations = set()
61
-
62
- # Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
63
- # constraint for the data type combination
64
- self.kernels_by_alignment = {}
65
-
66
- def add(self, operation):
67
- """
68
- Add an operation to the list of supported kernels
69
- """
70
- alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
71
- if alignment_key not in self.kernels_by_alignment:
72
- self.kernels_by_alignment[alignment_key] = []
73
- self.kernels_by_alignment[alignment_key].append(operation)
74
- self.math_operations.add(operation.tile_description.math_instruction.math_operation)
75
-
76
- def alignments(self, operand: str):
77
- """
78
- Returns an unsorted list of alignments supported by this data type combination
79
-
80
- :param operand: identifier of operand in question (e.g., A, B, C)
81
- :type operand: str
82
-
83
- :return: unsorted list of alignments supported by this data type combination
84
- :rtype: list
85
- """
86
- operand_idx = self._operand_idx(operand)
87
- return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()]
88
-
89
- @property
90
- def all_operations(self):
91
- """
92
- Returns a list of all operations supported by this data type combination
93
-
94
- :return: list of all operations supported by this data type combination
95
- :rtype: list
96
- """
97
- ops = []
98
- for _, alignment_ops in self.kernels_by_alignment.items():
99
- ops.extend(alignment_ops)
100
- return ops
101
-
102
- def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
103
- key = sorted(list(self.kernels_by_alignment.keys()))[0]
104
- kernels = self.kernels_by_alignment[key]
105
- if math_operation is not None:
106
- kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
107
- return kernels[0]
108
-
109
- def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
110
- """
111
- Returns operations satisfying the alignment constraints
112
-
113
- :param alignment_A: alignment constraint of operations to return
114
- :type alignment_A: int
115
- :param alignment_B: alignment constraint of operations to return
116
- :type alignment_B: int
117
- :param alignment_C: alignment constraint of operations to return
118
- :type alignment_C: int
119
- :param math_operation: math operation to consider
120
- :type math_operation: cutlass_cppgen.MathOperation
121
-
122
- :return: list of operations
123
- :rtype: list
124
- """
125
- key = f"{alignment_A} {alignment_B} {alignment_C}"
126
-
127
- if key not in self.kernels_by_alignment:
128
- og_key = key
129
- # Reconcile A, B, and C alignments by trying to align to the minimum
130
- min_alignment = min(alignment_A, alignment_B, alignment_C)
131
- key = f"{min_alignment} {min_alignment} {min_alignment}"
132
- if key not in self.kernels_by_alignment:
133
- # Finally, go through all available alignment combinations and find
134
- # one for which all values are less than those passed in.
135
- key = None
136
- alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
137
- for align_A, align_B, align_C in alignments:
138
- if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
139
- key = f"{align_A} {align_B} {align_C}"
140
- break
141
-
142
- if key is None:
143
- raise Exception(
144
- f"No operations of alignment {og_key} found for data type and layout "
145
- f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
146
- f"are {self.kernels_by_alignment.keys()}"
147
- )
148
-
149
- ops = self.kernels_by_alignment[key]
150
- if math_operation is not None:
151
- ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
152
- return ops
153
-
154
- def _operand_idx(self, key: str) -> int:
155
- operand_list = ["A", "B", "C"]
156
- if key not in operand_list:
157
- raise Exception(f"Unexpected operand {operand}")
158
-
159
- return operand_list.index(key)
160
-
161
- def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int:
162
- """
163
- Returns the most preferable alignment for a given shape and layout
164
-
165
- :param shape: extent of each dimension of the tensor
166
- :type shape: tuple
167
- :param layout: layout of the tensor
168
- :type layout: cutlass_cppgen.LayoutType
169
- :param operand: descriptor of the operand in question
170
- :type operand: str
171
-
172
- :return: maximum alignment supported by the data type combination and tensor size
173
- :rtype: int
174
- """
175
- operand_idx = self._operand_idx(operand)
176
-
177
- # Determine the leading dimension of the shape
178
- if layout == cutlass_cppgen.LayoutType.ColumnMajor:
179
- ld = shape[-2]
180
- elif layout == cutlass_cppgen.LayoutType.RowMajor:
181
- ld = shape[-1]
182
- elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
183
- ld = shape[-1]
184
- else:
185
- raise Exception(f"Unexpected or unsupported layout {layout}")
186
-
187
- for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True):
188
- alignment = int(alignments.split(" ")[operand_idx])
189
- if ld % alignment == 0:
190
- return alignment
191
-
192
- # Default to alignment of 1 if no others match
193
- return 1
194
-
195
- def sort(self):
196
- """
197
- Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
198
- """
199
- key = lambda op: (
200
- op.tile_description.threadblock_shape[0]
201
- * op.tile_description.threadblock_shape[1]
202
- * op.tile_description.threadblock_shape[2]
203
- )
204
- for alignment in self.kernels_by_alignment.keys():
205
- self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
206
-
207
- def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
208
- """
209
- Returns whether `math_operation` is supported by at least one operation.
210
-
211
- :param math_operation: math operation to consider
212
- :type math_operation: cutlass_cppgen.MathOperation
213
-
214
- :return: whether math_operation is supported by at least one operation
215
- :rtype: bool
216
- """
217
- return math_operation is None or math_operation in self.math_operations
218
-
219
-
220
- class ArchOptions:
221
- """
222
- Structure for keeping track of kernels available on a given compute capability
223
-
224
- :param target_cc: compute capability of the device on which kernels will be run
225
- :type target_cc: int
226
- :param kernel_cc: compute capability of the kernels to generate
227
- :type kernel_cc: int
228
- :param operation_kind: type of operation to register
229
- :type operation_kind: cutlass_library.OperationKind
230
- :param gemm_kinds: types of GEMM operations that can be included
231
- :type gemm_kinds: list
232
- :param allowed_math_operations: types of primitive math operations allowed
233
- :type allowed_math_operations: list
234
- """
235
-
236
- def __init__(
237
- self,
238
- target_cc: int,
239
- kernel_cc: int,
240
- operation_kind: cutlass_library.OperationKind,
241
- gemm_kinds: list,
242
- allowed_math_operations: list = [
243
- cutlass_library.MathOperation.multiply_add,
244
- cutlass_library.MathOperation.multiply_add_saturate,
245
- cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
246
- cutlass_library.MathOperation.multiply_add_fast_f32
247
- ]
248
- ):
249
- self.cc = kernel_cc
250
-
251
- # Dictionary with following structure:
252
- # Key: OpcodeClass
253
- # Value: Dictionary with the following structure:
254
- # Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType),
255
- # representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
256
- # Value: KernelsForDataType
257
- self.operations_by_opclass = {}
258
- self.op_class = None
259
- self.allowed_math_operations = allowed_math_operations
260
-
261
- if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100:
262
- return
263
-
264
- # Identify the method within CUTLASS generator script that generates kernel
265
- # descriptions for the target CC
266
- generate_function_name = "GenerateSM" + str(kernel_cc)
267
- if not hasattr(cutlass_library.generator, generate_function_name):
268
- cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
269
- return
270
- generate_function = getattr(cutlass_library.generator, generate_function_name)
271
-
272
- # Initialize a default manifest and populate it with valid kernel descriptions
273
- # for the target CC
274
- args = [
275
- "--kernels=all",
276
- f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
277
- ]
278
- manifest_args = cutlass_library.generator.define_parser().parse_args(args)
279
- manifest = cutlass_library.manifest.Manifest(manifest_args)
280
- generate_function(manifest, cutlass_cppgen._nvcc_version)
281
-
282
- if operation_kind not in manifest.operations:
283
- # No kernels generated for this architecture, this could be because the CUDA
284
- # toolkit is insufficient to support operations in this CC
285
- cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
286
- return
287
-
288
- # Only one CC should be returned, given the setup above of calling only the generation scripts
289
- # for a given CC
290
- if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
291
- raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
292
- "is sufficient for the architecture in question.")
293
-
294
- # Iterate through the available operations for this operation kind and
295
- # find available opclasses and data types
296
- for name, op_list in manifest.operations[operation_kind][kernel_cc].items():
297
- for op in op_list:
298
-
299
- if operation_kind == cutlass_library.OperationKind.Gemm:
300
- if op.gemm_kind not in gemm_kinds:
301
- continue
302
-
303
- mi = op.tile_description.math_instruction
304
- if mi.math_operation not in self.allowed_math_operations:
305
- continue
306
-
307
- # Prune operations that don't fit in shared memory
308
- td = td_from_profiler_op(op)
309
- if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
310
- continue
311
-
312
- if mi.opcode_class not in self.operations_by_opclass:
313
- self.operations_by_opclass[mi.opcode_class] = {}
314
-
315
- datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
316
- layout_comb = (op.A.layout, op.B.layout)
317
-
318
- # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
319
- if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
320
- # TF32 kernels only supported on SM80 and beyond
321
- if self.cc < 80:
322
- continue
323
- elif self.cc == 90 or self.cc == 100:
324
- if (op.A.element != cutlass_library.DataType.f32
325
- or op.B.element != cutlass_library.DataType.f32
326
- or op.C.element != cutlass_library.DataType.f32):
327
- continue
328
-
329
- datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)
330
-
331
- opclass_dict = self.operations_by_opclass[mi.opcode_class]
332
- key = (datatype_comb, layout_comb)
333
- if key not in opclass_dict:
334
- opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
335
- opclass_dict[key].add(op)
336
-
337
- # Set the default opclass to TensorOp, if available. Otherwise default to SIMT
338
- if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
339
- self.op_class = cutlass_library.OpcodeClass.TensorOp
340
- else:
341
- self.op_class = cutlass_library.OpcodeClass.Simt
342
-
343
- # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
344
- # Here, we generate additional versions via a generic TileDescription.
345
- if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
346
- self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}
347
-
348
- if operation_kind == cutlass_library.OperationKind.Gemm:
349
- types = [
350
- (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
351
- (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
352
- (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
353
- (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
354
- (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
355
- (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
356
- ]
357
-
358
- # Add FP8 A/B/C
359
- fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
360
- for type_comb in combinations_with_replacement(fp8_types, 3):
361
- types.append(type_comb)
362
-
363
- # Add FP8 A/B with FP32 C
364
- for type_comb in combinations_with_replacement(fp8_types, 2):
365
- types.append(type_comb + (cutlass_cppgen.DataType.f32,))
366
-
367
- layouts = [
368
- (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
369
- (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
370
- (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
371
- (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
372
- ]
373
- elif operation_kind == cutlass_library.OperationKind.Conv2d:
374
- types = [
375
- (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
376
- (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
377
- (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
378
- (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
379
- ]
380
-
381
- layouts = [
382
- (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
383
- ]
384
- else:
385
- raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")
386
-
387
- alignment = 1
388
- epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
389
- swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8
390
- for type_comb in types:
391
- for layout_comb in layouts:
392
- comb = (type_comb, layout_comb)
393
- if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
394
- continue
395
-
396
- A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
397
- B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
398
- C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
399
- math_inst = cutlass_library.MathInstruction(
400
- [1, 1, 1],
401
- type_comb[0],
402
- type_comb[1],
403
- type_comb[2],
404
- cutlass_library.OpcodeClass.Simt,
405
- cutlass_library.MathOperation.multiply_add
406
- )
407
-
408
- td = cutlass_library.TileDescription(
409
- [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)
410
-
411
- # Prune operations that don't fit in shared memory
412
- if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
413
- continue
414
-
415
- new_kernels = KernelsForDataType(type_comb, layout_comb)
416
-
417
- if operation_kind == cutlass_library.OperationKind.Gemm:
418
- new_operation = cutlass_library.manifest.GemmOperation(
419
- cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
420
- td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
421
- new_kernels.add(new_operation)
422
- elif operation_kind == cutlass_library.OperationKind.Conv2d:
423
- for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
424
- new_operation = cutlass_library.manifest.Conv2dOperation(
425
- conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
426
- A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
427
- group_mode=GroupMode.SingleGroup
428
- )
429
- new_kernels.add(new_operation)
430
-
431
- self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels
432
-
433
- # Sort all operations
434
- for oc in self.operations_by_opclass.keys():
435
- for comb in self.operations_by_opclass[oc].keys():
436
- self.operations_by_opclass[oc][comb].sort()
437
-
438
- def opclass_supports_combination(
439
- self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
440
- ) -> bool:
441
- """
442
- Returns whether the provided operation class supports the provided data type and layout combination
443
-
444
- :param op_class: operation class to consider
445
- :type op_class: cutlass_library.OpcodeClass
446
- :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
447
- :type datatype_comb: tuple[cutlass_library.DataType]
448
- :param layout_comb: tuple of data types for (layout_A, layout_B)
449
- :type layout_comb: tuple[cutlass_library.LayoutType]
450
- :param math_operation: math operation to consider or None if any can be considered
451
- :type math_operation: cutlass_cppgen.MathOperation
452
-
453
- :return: set of operation classes that support the provided data type and layout combination
454
- :rtype: set
455
- """
456
- if op_class not in self.operations_by_opclass:
457
- raise Exception(f"Unexpected or unsupported operation class {op_class}")
458
-
459
- if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
460
- if math_operation is not None:
461
- return operations.supports_math_operation(math_operation)
462
- else:
463
- return True
464
-
465
- return False
466
-
467
-
468
- def supporting_opclasses(
469
- self,
470
- element_a: cutlass_library.DataType,
471
- element_b: cutlass_library.DataType,
472
- element_accumulator: cutlass_library.DataType,
473
- layout_a: cutlass_library.LayoutType,
474
- layout_b: cutlass_library.LayoutType,
475
- math_operation: cutlass_library.MathOperation,
476
- ) -> set:
477
- """
478
- Returns a set of operation classes that support the provided data type combination
479
-
480
- :param element_a: data type of operand A
481
- :type element_a: cutlass_library.DataType
482
- :param element_b: data type of operand B
483
- :type element_b: cutlass_library.DataType
484
- :param element_accumulator: data type of accumulator
485
- :type element_accumulator: cutlass_library.DataType
486
- :param layout_a: layout of operand A
487
- :type layout_a: cutlass_library.LayoutType
488
- :param layout_b: layout of operand B
489
- :type layout_b: cutlass_library.LayoutType
490
- :param math_operation: math operation to consider
491
- :type math_operation: cutlass_cppgen.MathOperation
492
-
493
- :return: set of operation classes that support the provided data type combination
494
- :rtype: set
495
- """
496
- supporting_op_classes = set()
497
- datatype_comb = (element_a, element_b, element_accumulator)
498
- layout_comb = (layout_a, layout_b)
499
-
500
- for op_class in self.operations_by_opclass.keys():
501
- if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
502
- supporting_op_classes.add(op_class)
503
- return supporting_op_classes
504
-
505
- def operations(
506
- self,
507
- op_class: cutlass_library.OpcodeClass,
508
- element_a: cutlass_library.DataType,
509
- element_b: cutlass_library.DataType,
510
- element_accumulator: cutlass_library.DataType,
511
- layout_a: cutlass_library.LayoutType,
512
- layout_b: cutlass_library.LayoutType,
513
- math_operation: cutlass_library.MathOperation,
514
- ) -> KernelsForDataType:
515
- """
516
- Returns whether the provided operation class supports the provided data type combination
517
-
518
- :param op_class: operation class to consider
519
- :type op_class: cutlass_library.OpcodeClass
520
- :param element_a: data type of operand A
521
- :type element_a: cutlass_library.DataType
522
- :param element_b: data type of operand B
523
- :type element_b: cutlass_library.DataType
524
- :param element_accumulator: data type of accumulator
525
- :type element_accumulator: cutlass_library.DataType
526
- :param layout_a: layout of operand A
527
- :type layout_a: cutlass_library.LayoutType
528
- :param layout_b: layout of operand B
529
- :type layout_b: cutlass_library.LayoutType
530
- :param math_operation: math operation to consider
531
- :type math_operation: cutlass_cppgen.MathOperation
532
-
533
- :return: container of kernels by alignment supported by the provided combination of parameters
534
- :rtype: KernelsForDataType
535
- """
536
- datatype_comb = (element_a, element_b, element_accumulator)
537
- layout_comb = (layout_a, layout_b)
538
- if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
539
- raise Exception(
540
- f"Data type layout combination {datatype_comb}, {layout_comb} "
541
- f"is not supported by opcode class {op_class} on CC {self.cc}."
542
- )
543
- return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]
544
-
545
-
546
- class OptionRegistry:
547
- """
548
- Container of all architecture-specific options
549
-
550
- :param target_cc: compute capability of the device on which operations will be run
551
- :type target_cc: int
552
- """
553
-
554
- def __init__(self, target_cc: int):
555
- self.registry = {}
556
-
557
- if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
558
- raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")
559
-
560
- gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
561
- operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
562
- # Construct options for each CC
563
- for kernel_cc in _generator_ccs:
564
- self.registry[kernel_cc] = {}
565
- for opkind in operation_kinds:
566
- self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)
567
-
568
- def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
569
- return self.registry.get(cc, None)[op_kind]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py DELETED
@@ -1,36 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
34
- from cutlass_cppgen.op.gemm import Gemm
35
- from cutlass_cppgen.op.gemm_grouped import GroupedGemm
36
- from cutlass_cppgen.op.op import OperationBase
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py DELETED
@@ -1,997 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Ease-of-use interface for constructing, compiling, and running CONVs
35
-
36
- The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
37
- CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
38
- Under the hood, the interface will select sensible default parameters for the many template
39
- parameters for CUTLASS CONVs.
40
-
41
- Note: optimal performance is not to be expected from this interface. To achieve optimal
42
- performance, one should specify and tune each configuration parameter.
43
-
44
- The simplest example of using this interface is the following:
45
-
46
- .. highlight:: python
47
- .. code-block:: python
48
-
49
- # A, B, C, and D are torch/numpy/cupy tensor objects
50
- plan = cutlass_cppgen.op.Conv(A, B, C, D)
51
- plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
52
-
53
- One can also use the interface by specifying data types of operands at construction
54
- and using different tensor objects with these data types at runtime:
55
-
56
- .. highlight:: python
57
- .. code-block:: python
58
-
59
- # The following is shorthand for:
60
- # cutlass_cppgen.op.Conv2d(kind="fprop",
61
- # element_A=torch.float32, element_B=torch.float32,
62
- # element_C=torch.float32, element_D=torch.float32,
63
- # element_accumulator=torch.float32)
64
- plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
65
-
66
- A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
67
- B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
68
- C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
69
- D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda')
70
- plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
71
-
72
- A = torch.rand((32, 128), dtype=torch.float32, device='cuda')
73
- B = torch.rand((128, 256), dtype=torch.float32, device='cuda')
74
- C = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
75
- D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda')
76
- plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
77
-
78
- The interface additionally enables one to decouple the compilation of the underlying CUTLASS
79
- kernel from its execution:
80
-
81
- .. highlight:: python
82
- .. code-block:: python
83
-
84
- plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
85
-
86
- # Do other work...
87
-
88
- plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
89
-
90
- # Do other work...
91
-
92
- plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
93
-
94
- Elementwise activation functions are easily fused to the GEMM via the interface:
95
-
96
- .. highlight:: python
97
- .. code-block:: python
98
-
99
- plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
100
- plan.activation = cutlass_cppgen.epilogue.relu
101
-
102
- Operations can also be run asynchronously:
103
-
104
- .. highlight:: python
105
- .. code-block:: python
106
-
107
- plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
108
- args = plan.run()
109
-
110
- # Do other work...
111
-
112
- args.sync()
113
- """
114
-
115
- from __future__ import annotations
116
- from typing import Optional
117
- from cutlass_cppgen.utils.lazy_import import lazy_import
118
- cuda = lazy_import("cuda.cuda")
119
- cudart = lazy_import("cuda.cudart")
120
- from cutlass_library import (
121
- ConvKind,
122
- ConvMode,
123
- DataTypeSize,
124
- IteratorAlgorithm,
125
- OperationKind,
126
- SplitKMode,
127
- StrideSupport,
128
- )
129
-
130
- import cutlass_cppgen
131
- from cutlass_cppgen import epilogue
132
- from cutlass_cppgen.backend import compiler
133
- from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
134
- from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
135
- from cutlass_cppgen.backend.library import TensorDescription, TileDescription
136
- from cutlass_cppgen.op.op import OperationBase
137
- from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
138
- from cutlass_cppgen.utils import check, datatypes
139
-
140
-
141
- class Conv2d(OperationBase):
142
- """
143
- Constructs a ``Conv2d`` object.
144
-
145
- The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C,
146
- along with the data type of output D and that used for accumulation, are bound to the ``Conv``
147
- object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
148
-
149
- The constructor has optional parameters for flexibly setting these parameters. The following
150
- constructors are equivalent:
151
-
152
- .. highlight:: python
153
- .. code-block:: python
154
-
155
- # Use F32 for A, B, C, D, and accumulation in fprop
156
-
157
- # Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
158
- Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
159
-
160
- # Explicitly specify the data types to use for A, B, C, and D.
161
- Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
162
- element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
163
-
164
- # Set the data types and elements from existing tensors. Note that one can use different tensors when
165
- # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
166
- # have the same data type as those passed in here).
167
- # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
168
- Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
169
-
170
- # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
171
- # those passed in via the generic ``element``
172
- Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
173
- element=cutlass_cppgen.DataType.f32)
174
-
175
- The order of precedence for the setting of the data type for a given operand/output is as follows:
176
- 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
177
- 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
178
- 3) Otherwise, use the generic values (e.g., ``element``)
179
-
180
- :param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
181
- :type kind: str
182
- :param A: tensor representing data type of operand A
183
- :param B: tensor representing data type of operand B
184
- :param C: tensor representing data type of operand C
185
- :param D: tensor representing data type of operand D
186
- :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
187
- :param beta: scalar parameter beta from GEMM operation that scales operand C
188
- :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
189
- :type element: cutlass_cppgen.DataType
190
- :param element_A: data type to be used for operand A
191
- :type element_A: cutlass_cppgen.DataType
192
- :param element_B: data type to be used for operand B
193
- :type element_B: cutlass_cppgen.DataType
194
- :param element_C: data type to be used for operand C
195
- :type element_C: cutlass_cppgen.DataType
196
- :param element_D: data type to be used for operand D
197
- :type element_D: cutlass_cppgen.DataType
198
- :param element_accumulator: data type to be used in accumulation of the product of operands A and B
199
- :type element_accumulator: cutlass_cppgen.DataType
200
- :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
201
- :type cc: int
202
- :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
203
- :type kernel_cc: int
204
- """
205
- def __init__(
206
- self, kind="fprop",
207
- A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
208
- element=None,
209
- element_A=None, element_B=None, element_C=None, element_D=None,
210
- element_accumulator=None,
211
- cc: int = None, kernel_cc: int = None
212
- ):
213
- super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
214
- # Verify the kernel cc
215
- if self.current_cc in [90, 100, 101, 103]:
216
- # The Conv2d kernel on Hopper (SM90) is currently unsupported
217
- # Revert to use SM80-tagged kernels
218
- cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
219
- self.specified_kernel_cc = 80
220
- self._reset_options(80)
221
-
222
- # The arch is used in testing
223
- self.arch = self.current_cc
224
- self.name = "conv2d" + kind
225
-
226
- # The convolution kind. (concept: cutlass_library.library.ConvKind)
227
- self.conv_kind = datatypes.getattr_enum(ConvKind, kind)
228
-
229
- # The element types (concept: cutlass library types) of A, B, C, and D
230
- elements = []
231
- layouts = []
232
-
233
- # Complete the data types based on user-provided arguments
234
- for elt, tens, name in zip([element_A, element_B, element_C, element_D],
235
- [A, B, C, D],
236
- ["A", "B", "C", "D"]):
237
- if elt is not None and tens is not None:
238
- raise Exception(f'Must not specify both element_{name} and tensor {name}')
239
- if elt is None and tens is None and element is None:
240
- raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
241
-
242
- elt_to_set = None
243
- lay_to_set = None
244
-
245
- if tens is not None:
246
- elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
247
- else:
248
- elt_to_set = elt if elt is not None else element
249
-
250
- assert elt_to_set is not None
251
-
252
- # Currently we only support layout TensorNHWC
253
- lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
254
- elements.append(datatypes.library_type(elt_to_set))
255
- layouts.append(lay_to_set)
256
-
257
- self._element_a, self._element_b, self._element_c, self._element_d = elements
258
- self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
259
-
260
- self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta
261
-
262
- if element_accumulator is None:
263
- self._element_accumulator = self._element_c
264
- else:
265
- self._element_accumulator = datatypes.library_type(element_accumulator)
266
-
267
- # Default inputs if none is supplied in run()
268
- self.A = A
269
- self.B = B
270
- self.C = C
271
- self.D = D
272
-
273
- self.alpha = alpha
274
- self.beta = beta
275
-
276
- # We only specify the stride of the swizzling functor here
277
- # The actual swizzling functor is determined in run based on conv_kind and stride
278
- self._swizzling_stride = 1
279
-
280
- # Arguments that will be set to default value in _reset_operations
281
- # The default tile_description and op_class are fetched from manifest of cutlass library
282
- self._tile_description = None
283
- self.op_class = None
284
- # The default identity epilogue will be created
285
- self.epilogue_functor = None
286
-
287
- self._reset_operations()
288
-
289
- # Arguments that will be determined online based on arguments of "run"
290
- # based on stride, input/output channels, alignment, and conv_kind
291
- self._iterator_algorithm = None
292
- self._stride_support = None
293
-
294
- def _reset_operations(self, reset_epilogue: bool = True):
295
- # Set the default op class
296
- datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
297
- layout_comb = (self._layout_a, self._layout_b)
298
-
299
- self.possible_op_classes = self.options.supporting_opclasses(
300
- self._element_a, self._element_b, self._element_accumulator,
301
- self._layout_a, self._layout_b, self._math_operation
302
- )
303
-
304
- if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
305
- self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
306
- elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
307
- self.opclass = cutlass_cppgen.OpcodeClass.Simt
308
- else:
309
- if self._math_operation is not None:
310
- math_op_str = f' and math operation {self._math_operation}'
311
- else:
312
- math_op_str = ''
313
-
314
- raise Exception(f'No kernel configuration found for supported data type and layout '
315
- f'combination {datatype_comb}x{layout_comb}{math_op_str}')
316
-
317
- if reset_epilogue:
318
- self._reset_epilogue_functor_activation(epilogue.identity)
319
-
320
- self.alignment_pref_A = min(
321
- 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
322
- self.alignment_pref_B = min(
323
- 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
324
- self.alignment_pref_C = min(
325
- 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C")))
326
-
327
- #
328
- # Tile description Related
329
- #
330
-
331
- @property
332
- def tile_description(self) -> TileDescription:
333
- """
334
- Returns the tile description
335
- """
336
- return self._tile_description
337
-
338
- @tile_description.setter
339
- def tile_description(
340
- self, td=None):
341
- """
342
- Set the tile description
343
-
344
- :param td: tile description
345
- :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
346
- {
347
- "threadblock_shape": [int, int, int],
348
- "warp_count": [int, int, int],
349
- "stages": int,
350
- "instruction_shape": [int, int, int] (optional),
351
- "cluster_shape": [int, int, int] (optional)
352
- }
353
- """
354
- if td is None:
355
- return
356
- if isinstance(td, dict):
357
- if self._tile_description is None:
358
- op = self.possible_operations.default_operation(self._math_operation)
359
- self._tile_description = datatypes.td_from_profiler_op(op)
360
- if "cluster_shape" in td.keys():
361
- if td["cluster_shape"] != [1, 1, 1]:
362
- cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
363
- td["cluster_shape"] = [1, 1, 1]
364
- td = self._tile_description.clone_and_update(td)
365
-
366
- valid, msg = self._valid_tile_description(td)
367
- if valid:
368
- self._tile_description = td
369
- else:
370
- raise Exception(msg)
371
-
372
- def _valid_tile_description(self, td: TileDescription) -> tuple:
373
- """
374
- Checks whether the provided tile description is valid for the given compute capability. At present,
375
- this checks the following:
376
-
377
- - Does the tile description use a number of stages supported by the compute capability in question?
378
- - Does the tile size requested fit within shared memory?
379
- - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
380
- more non-unit cluster dimensions for pre-SM90 architectures)?
381
- - Is the kernel schedule being used supported on the architecture in question?
382
-
383
- :param td: tile description to validate
384
- :type td: cutlass_cppgen.backend.TileDescription
385
- :return: tuple in which the first element is a bool indicating that the tile description is valid
386
- and the second element is a string providing an optional error message.
387
- :rtype: tuple
388
- """
389
- valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
390
- if not valid:
391
- return (valid, msg)
392
-
393
- valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
394
- if not valid:
395
- return (valid, msg)
396
-
397
- return valid, msg
398
-
399
- def tile_descriptions(self) -> list:
400
- """
401
- Returns a list of valid tile descriptions for the operations
402
-
403
- :returns: list of valid tile descriptions for the operations
404
- :rtype: list
405
- """
406
- descriptions = []
407
- description_str = []
408
- for op in self.possible_operations.all_operations:
409
- td = datatypes.td_from_profiler_op(op)
410
-
411
- if self._math_operation is not None:
412
- if td.math_instruction.math_operation != self._math_operation:
413
- continue
414
-
415
- if str(td) not in description_str:
416
- description_str.append(str(td))
417
- descriptions.append(td)
418
- return descriptions
419
-
420
- #
421
- # Swizzling functor Related
422
- #
423
-
424
- @property
425
- def swizzling_stride(self):
426
- """
427
- Returns the stride of swizzling currently being used by the Conv2d
428
-
429
- :return: swizzing stride
430
- """
431
- return self._swizzling_stride
432
-
433
- @swizzling_stride.setter
434
- def swizzling_stride(self, stride: int):
435
- """
436
- Sets the swizzling functor to the type specified by `swizzling_functor`
437
- """
438
- if not isinstance(stride, int):
439
- raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
440
- self._swizzling_stride = stride
441
-
442
- def _propose_swizzling_functor(self, stride):
443
- """
444
- Automatically propose the swizzling functor based on the stride
445
- """
446
- if self.conv_kind == ConvKind.Dgrad:
447
- if stride[0] != 1 or stride[1] != 1:
448
- return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
449
-
450
- return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
451
-
452
- #
453
- # Iterator Algorithm Related
454
- #
455
-
456
- @property
457
- def iterator_algorithm(self) -> IteratorAlgorithm:
458
- """
459
- Returns the iterator algorithm
460
- """
461
- return self._iterator_algorithm
462
-
463
- @iterator_algorithm.setter
464
- def iterator_algorithm(self, alg: str):
465
- """
466
- Sets the iterator algorithm
467
-
468
- :param alg: The iterator algorithm
469
- :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels"
470
- """
471
- iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)
472
-
473
- # Check if the iterator algorithm is valid
474
- if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
475
- raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
476
-
477
- self._iterator_algorithm = iterator_alg
478
-
479
- def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
480
- """
481
- Propose a valid iterator algorithm based on problem size and alignment
482
- """
483
- if self.conv_kind == ConvKind.Fprop:
484
- # Check whether the fixed channel is applicable
485
- if problem_size.C == alignment_a:
486
- return IteratorAlgorithm.FixedChannels
487
- elif (problem_size.C % alignment_a == 0 and
488
- problem_size.R <= 32 and problem_size.S <= 32):
489
- return IteratorAlgorithm.Optimized
490
- else:
491
- return IteratorAlgorithm.Analytic
492
- elif self.conv_kind == ConvKind.Dgrad:
493
- if (problem_size.K % alignment_a == 0 and
494
- problem_size.R <= 32 and problem_size.S <= 32 and
495
- problem_size.C % alignment_b == 0):
496
- return IteratorAlgorithm.Optimized
497
- else:
498
- return IteratorAlgorithm.Analytic
499
- elif self.conv_kind == ConvKind.Wgrad:
500
- if (problem_size.K % alignment_a == 0 and
501
- problem_size.C % alignment_b == 0):
502
- return IteratorAlgorithm.Optimized
503
- else:
504
- return IteratorAlgorithm.Analytic
505
-
506
- def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
507
- """
508
- Validate whether the user provide iterator algorithm works for the given problem size
509
- """
510
- if self.conv_kind == ConvKind.Fprop:
511
- if iterator_algorithm == IteratorAlgorithm.FixedChannels:
512
- return problem_size.C == alignment_a
513
- elif iterator_algorithm == IteratorAlgorithm.Optimized:
514
- return (problem_size.C % alignment_a == 0 and
515
- problem_size.R <= 32 and problem_size.S <= 32)
516
- elif iterator_algorithm == IteratorAlgorithm.FewChannels:
517
- return problem_size.C % alignment_a == 0
518
- elif self.conv_kind == ConvKind.Dgrad:
519
- if iterator_algorithm == IteratorAlgorithm.Optimized:
520
- return (problem_size.K % alignment_a == 0 and
521
- problem_size.R <= 32 and problem_size.S <= 32 and
522
- problem_size.C % alignment_b == 0)
523
- elif self.conv_kind == ConvKind.Wgrad:
524
- if iterator_algorithm == IteratorAlgorithm.Optimized:
525
- return (problem_size.K % alignment_a == 0 and
526
- problem_size.C % alignment_b == 0)
527
-
528
- return True
529
-
530
- #
531
- # Stride Support Related
532
- #
533
-
534
- def _propose_stride_support(self, stride):
535
- if self.conv_kind == ConvKind.Dgrad:
536
- if stride[0] == 1 and stride[1] == 1:
537
- return StrideSupport.Unity
538
-
539
- return StrideSupport.Strided
540
-
541
- #
542
- # Construct and Compilation
543
- #
544
-
545
- def construct(
546
- self, tile_description: TileDescription = None,
547
- alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
548
- iterator_algorithm: IteratorAlgorithm = None,
549
- stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
550
- epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
551
- """
552
- Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
553
- kernel specification of the ``Conv2d`` object.
554
-
555
- :param tile_description: tile description specifying shapes and operand types to use in the kernel
556
- :type tile_description: cutlass_cppgen.backend.TileDescription
557
- :param alignment_A: alignment of operand A
558
- :type alignment_A: int
559
- :param alignment_B: alignment of operand B
560
- :type alignment_B: int
561
- :param alignment_C: alignment of operand C
562
- :type alignment_C: int
563
- :param iterator_algorithm: the iterator algorithm used
564
- :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
565
- :param stride_support: the stride support of dgrad
566
- :type stride_support: cutlass_library.library.StrideSupport
567
- :param swizzling_functor: the swizzling functor
568
- :type swizzling_functor: cutlass_cppgen.swizzle
569
- :param epilogue_functor: the epilogue functor
570
-
571
- :return: operation that was constructed
572
- :rtype: cutlass_cppgen.backend.Conv2dOperation
573
- """
574
- # Get alignment
575
- alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
576
- alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
577
- alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)
578
-
579
- tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
580
- tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
581
- tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
582
-
583
- if tile_description is None:
584
- if self.tile_description is not None:
585
- tile_description = self.tile_description
586
- else:
587
- op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
588
- tile_description = datatypes.td_from_profiler_op(op)
589
- else:
590
- valid, err_str = self._valid_tile_description(tile_description)
591
- if not valid:
592
- raise Exception(f"Invalid tile description. {err_str}")
593
- self.tile_description = tile_description
594
-
595
- if iterator_algorithm is None:
596
- # If the iterator algorithm is already set
597
- if self.iterator_algorithm is not None:
598
- iterator_algorithm = self.iterator_algorithm
599
- else:
600
- # Otherwise, we conservatively use the analytic iterator for correctness
601
- iterator_algorithm = IteratorAlgorithm.Analytic
602
-
603
- if stride_support is None:
604
- # If the stride support is already set
605
- if self._stride_support is not None:
606
- stride_support = self._stride_support
607
- else:
608
- # Otherwise, we assume strided
609
- stride_support = StrideSupport.Strided
610
-
611
- if swizzling_functor is None:
612
- # If the swizzling functor is already set
613
- swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))
614
-
615
- if epilogue_functor is None:
616
- if self.epilogue_functor is not None:
617
- epilogue_functor = self.epilogue_functor
618
- else:
619
- epilogue_functor = self._create_epilogue_functor_activation(self._activation)
620
-
621
- # Reset the alignment of the epilogue functor
622
- epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)
623
-
624
- operation = Conv2dOperation(
625
- conv_kind=self.conv_kind,
626
- iterator_algorithm=iterator_algorithm,
627
- arch=self.current_cc,
628
- tile_description=tile_description,
629
- A=tensor_A, B=tensor_B, C=tensor_C,
630
- stride_support=stride_support,
631
- epilogue_functor=epilogue_functor,
632
- swizzling_functor=swizzling_functor,
633
- )
634
-
635
- return operation
636
-
637
- def compile(self, tile_description: TileDescription = None,
638
- alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
639
- iterator_algorithm: IteratorAlgorithm = None,
640
- stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
641
- epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
642
- """
643
- Emits and compiles the kernel currently specified. If ``tile_description`` and any
644
- of the ``alignment`` parameters are set, the kernel will be chosen using this
645
- tile description and alignments. Otherwise, a default tile description and alignment
646
- will be used.
647
-
648
- ::param tile_description: tile description specifying shapes and operand types to use in the kernel
649
- :type tile_description: cutlass_cppgen.backend.TileDescription
650
- :param alignment_A: alignment of operand A
651
- :type alignment_A: int
652
- :param alignment_B: alignment of operand B
653
- :type alignment_B: int
654
- :param alignment_C: alignment of operand C
655
- :type alignment_C: int
656
- :param iterator_algorithm: the iterator algorithm used
657
- :type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
658
- :param stride_support: the stride support of dgrad
659
- :type stride_support: cutlass_library.library.StrideSupport
660
- :param swizzling_functor: the swizzling functor
661
- :type swizzling_functor: cutlass_cppgen.swizzle
662
- :param epilogue_functor: the epilogue functor
663
-
664
- :return: operation that was compiled
665
- :rtype: cutlass_cppgen.backend.Conv2dOperation
666
- """
667
-
668
- self.operation = self.construct(
669
- tile_description, alignment_A, alignment_B, alignment_C,
670
- iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
671
-
672
- if print_module:
673
- print(self.operation.rt_module.emit())
674
-
675
- compiler.add_module([self.operation,])
676
- return self.operation
677
-
678
- #
679
- # Run Related
680
- #
681
-
682
- def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
683
- """
684
- Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
685
- is raised if it does not.
686
-
687
- :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
688
- :type tensor: numpy/cupy/torch array/tensor object
689
- :param ref_dtype: data type for the tensor that this object was initialized to
690
- :param name: identifier of the tensor to verify. Used in raising exceptions
691
- :type name: str
692
- """
693
- dtype, _ = datatypes.get_datatype_and_layout(tensor)
694
- if dtype != ref_type:
695
- raise Exception(f'Tensor {name} with type and layout {dtype} '
696
- f'does not match the expected type of {ref_type}.')
697
-
698
- def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
699
- if self.conv_kind == ConvKind.Fprop:
700
- input = A
701
- weight = B
702
- output = C
703
- output_tensor = "C"
704
- elif self.conv_kind == ConvKind.Dgrad:
705
- output = A
706
- weight = B
707
- input = C
708
- output_tensor = "A"
709
- elif self.conv_kind == ConvKind.Wgrad:
710
- output = A
711
- input = B
712
- weight = C
713
- output_tensor = "A"
714
- else:
715
- raise Exception(f"Convolution kind {self.conv_kind} is not supported")
716
-
717
- N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
718
- K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
719
- _, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")
720
-
721
- problem_size = Conv2DProblemSize(
722
- N_, H_, W_, C_,
723
- K_, R_, S_, C_,
724
- padding[0], padding[1],
725
- stride[0], stride[1],
726
- dilation[0], dilation[1],
727
- ConvMode.CrossCorrelation,
728
- 1, 1
729
- )
730
-
731
- if P_ != problem_size.P or Q_ != problem_size.Q:
732
- raise Exception(
733
- f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")
734
-
735
- return problem_size
736
-
737
- def run(self, A=None, B=None, C=None, D=None,
738
- stride=(1, 1), padding=(0, 0), dilation=(1, 1),
739
- alpha=None, beta=None,
740
- split_k=("serial", 1), sync: bool = True,
741
- print_module: bool = False,
742
- stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
743
- """
744
- Runs the kernel currently specified. If it has not already been, the kernel is emitted and
745
- compiled. Tensors holding operands and outputs of the kernel are sourced either from the
746
- ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
747
- parameters provided in the call, or from those
748
- passed in on the construction of this object -- one of the two must be specified.
749
-
750
- By default, this call returns only once the kernel has completed. To launch the kernel
751
- and immediately return, set ``sync=False``. In this case, it is the responsibility of the
752
- caller to syncrhonize the results of the kernel before attempting to access outputs
753
- by calling ``sync()`` on the arguments returned from this call.
754
-
755
- :param A: tensor representing data type and layout of operand A
756
- :param B: tensor representing data type and layout of operand B
757
- :param C: tensor representing data type and layout of operand C
758
- :param D: tensor representing data type and layout of operand D
759
- :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
760
- :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
761
- :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
762
- :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
763
- :param beta: scalar parameter beta from GEMM operation that scales operand C
764
- :param split_k: a tuple (split_k_mode, split_k_slices)
765
- :param sync: whether the call should wait for the kernel to complete before returning
766
- :type sync: bool
767
- :param print_module: whether to print the emitted C++ code
768
- :type print_module: bool
769
- :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
770
- :type stream: :class:`cuda.cuda.CUstream`
771
-
772
- :return: arguments passed in to the kernel
773
- :rtype: cutlass_cppgen.backend.Conv2dArguments
774
- """
775
- if not stream:
776
- stream = cuda.CUstream(0)
777
- super().run_setup()
778
-
779
- A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
780
- B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
781
- C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
782
- D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
783
- alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
784
- beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
785
-
786
- # handle the case when there is no C
787
- if C is None:
788
- if beta != 0:
789
- raise Exception(f"With beta {beta} != 0, C has to be provided.")
790
- else:
791
- C = D
792
-
793
- # Construct problem size based on input
794
- # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
795
- problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)
796
-
797
- # Propose stride support based on input
798
- stride_support = self._propose_stride_support(stride)
799
-
800
- # Propose swizzling functor
801
- swizzling_functor = self._propose_swizzling_functor(stride)
802
-
803
- shape_a = datatypes.get_tensor_shape(A, op="CONV")
804
- shape_b = datatypes.get_tensor_shape(B, op="CONV")
805
- shape_c = datatypes.get_tensor_shape(C, op="CONV")
806
-
807
- # Get the alignment
808
- alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
809
- alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
810
- alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")
811
-
812
- alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
813
- alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
814
- alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)
815
-
816
- # Propose iterator algorithm based on input
817
- if self._iterator_algorithm is None:
818
- # Propose a default iterator algorithm based on the problem size
819
- iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
820
- else:
821
- if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
822
- iterator_algorithm = self._iterator_algorithm
823
- else:
824
- raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")
825
-
826
- epilogue_args = [alpha, beta]
827
-
828
- if hasattr(self, "_activation_args"):
829
- if isinstance(self._activation_args, list):
830
- epilogue_args += self._activation_args
831
- else:
832
- epilogue_args.append(self._activation_args)
833
-
834
- if split_k[0] == "parallel" and split_k[1] > 1:
835
- epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
836
- else:
837
- epilogue_functor = self.epilogue_functor
838
-
839
- # The alignment is determined by the iterator function (I believe)
840
- self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
841
- alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
842
- swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)
843
-
844
- # Create reduction operation for parallel split-k
845
- if split_k[0] == "parallel" and split_k[1] > 1:
846
- epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
847
- self.reduction_operation = ReductionOperation(
848
- shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
849
- element_accumulator=self._element_accumulator,
850
- element_compute=self._element_accumulator,
851
- epilogue_functor=epilogue_functor_reduction,
852
- count=alignment_c
853
- )
854
- if print_module:
855
- print(self.reduction_operation.rt_module.emit())
856
- compiler.add_module([self.reduction_operation,])
857
-
858
- arguments = Conv2dArguments(
859
- operation=self.operation, problem_size=problem_size,
860
- A=A, B=B, C=C, D=D,
861
- output_op=self.operation.epilogue_type(*epilogue_args),
862
- split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
863
- split_k_slices=split_k[1],
864
- stream=stream
865
- )
866
-
867
- self.operation.run(arguments)
868
-
869
- if split_k[0] == "parallel" and split_k[1] > 1:
870
- implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
871
- reduction_arguments = ReductionArguments(
872
- self.reduction_operation,
873
- problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
874
- partitions=split_k[1],
875
- workspace=arguments.ptr_D,
876
- destination=D,
877
- source=C,
878
- output_op=self.reduction_operation.epilogue_type(*epilogue_args),
879
- stream=stream
880
- )
881
- self.reduction_operation.run(reduction_arguments)
882
-
883
- if sync:
884
- if split_k[0] == "parallel" and split_k[1] > 1:
885
- reduction_arguments.sync()
886
-
887
- # Free memory allocated by args because we are not
888
- # calling `arguments.sync()` in this case (which will free memory)
889
- arguments.free()
890
- else:
891
- arguments.sync()
892
-
893
- return arguments
894
-
895
- #
896
- # Helper functions
897
- #
898
- @staticmethod
899
- def output_size(input_size, weight_size, padding, stride, dilation):
900
- problem_size = Conv2DProblemSize(
901
- *input_size,
902
- *weight_size,
903
- padding[0], padding[1],
904
- stride[0], stride[1],
905
- dilation[0], dilation[1],
906
- ConvMode.CrossCorrelation,
907
- 1, 1
908
- )
909
- return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
910
-
911
-
912
- #
913
- # Easy to use interfaces for fprop, wgrad, and dgrad
914
- #
915
-
916
- class Conv2dFprop(Conv2d):
917
- def __init__(
918
- self,
919
- input=None, weight=None, C=None, output=None, alpha=1, beta=0,
920
- element=None,
921
- element_input=None, element_weight=None, element_C=None, element_output=None,
922
- element_accumulator=None,
923
- cc: int = None, kernel_cc: int = None):
924
- A, B, D = input, weight, output
925
- element_A, element_B, element_D = element_input, element_weight, element_output
926
- super().__init__(
927
- "fprop", A, B, C, D, alpha, beta, element,
928
- element_A, element_B, element_C, element_D,
929
- element_accumulator, cc, kernel_cc)
930
-
931
- def run(
932
- self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
933
- stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
934
- sync: bool = True, print_module: bool = False,
935
- stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
936
-
937
- if not stream:
938
- stream = cuda.CUstream(0)
939
-
940
- A, B, D = input, weight, output
941
- return super().run(
942
- A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
943
-
944
-
945
- class Conv2dDgrad(Conv2d):
946
- def __init__(
947
- self,
948
- grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
949
- element=None,
950
- element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
951
- element_accumulator=None,
952
- cc: int = None, kernel_cc: int = None):
953
- A, B, D = grad_output, weight, grad_input
954
- element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
955
- super().__init__(
956
- "dgrad", A, B, C, D, alpha, beta, element,
957
- element_A, element_B, element_C, element_D,
958
- element_accumulator, cc, kernel_cc)
959
-
960
- def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
961
- stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
962
- sync: bool = True, print_module: bool = False,
963
- stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
964
- #
965
- if not stream:
966
- stream = cuda.CUstream(0)
967
-
968
- A, B, D = grad_output, weight, grad_input
969
- return super().run(
970
- A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
971
-
972
-
973
- class Conv2dWgrad(Conv2d):
974
- def __init__(
975
- self,
976
- grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
977
- element=None,
978
- element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
979
- element_accumulator=None,
980
- cc: int = None, kernel_cc: int = None):
981
- A, B, D = grad_output, input, grad_weight
982
- element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
983
- super().__init__(
984
- "wgrad", A, B, C, D, alpha, beta, element,
985
- element_A, element_B, element_C, element_D,
986
- element_accumulator, cc, kernel_cc)
987
-
988
- def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
989
- stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
990
- sync: bool = True, print_module: bool = False,
991
- stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
992
- if not stream:
993
- stream = cuda.CUstream(0)
994
-
995
- A, B, D = grad_output, input, grad_weight
996
- return super().run(
997
- A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py DELETED
@@ -1,725 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Ease-of-use interface for constructing, compiling, and running GEMMs.
35
-
36
- The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
37
- GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
38
- Under the hood, the interface will select sensible default parameters for the many template
39
- parameters for CUTLASS GEMMs.
40
-
41
- Note: optimal performance is not to be expected from this interface. To achieve optimal
42
- performance, one should specify and tune each configuration parameter.
43
-
44
- The simplest example of using this interface is the following:
45
-
46
- .. highlight:: python
47
- .. code-block:: python
48
-
49
- # A, B, C, and D are torch/numpy/cupy tensor objects
50
- plan = cutlass_cppgen.op.Gemm(A, B, C, D)
51
- plan.run()
52
-
53
-
54
- One can also use the interface by specifying data types of operands at construction
55
- and using different tensor objects with these data types at runtime:
56
-
57
- .. highlight:: python
58
- .. code-block:: python
59
-
60
- # The following is shorthand for:
61
- # cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
62
- # element_C=torch.float32, element_D=torch.float32,
63
- # element_accumulator=torch.float32,
64
- # layout=cutlass_cppgen.LayoutType.RowMajor)
65
- plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
66
-
67
- A0 = torch.rand((128, 256), device='cuda')
68
- B0 = torch.rand((256, 64), device='cuda')
69
- C0 = torch.zeros((128, 64), device='cuda')
70
- D0 = torch.zeros((128, 64), device.'cuda')
71
- plan.run(A0, B0, C0, D0)
72
-
73
- A = torch.rand((32, 128), device='cuda')
74
- B = torch.rand((128, 256), device='cuda')
75
- C = torch.zeros((32, 256), device='cuda')
76
- D = torch.zeros((32, 256), device.'cuda')
77
- plan.run(A1, B1, C1, D1)
78
-
79
- The interface additionally enables one to decouple the compilation of the underlying CUTLASS
80
- kernel from its execution:
81
-
82
- .. highlight:: python
83
- .. code-block:: python
84
-
85
- plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
86
- plan.compile()
87
-
88
- # Do other work...
89
-
90
- plan.run(A0, B0, C0, D0)
91
-
92
- # Do other work...
93
-
94
- plan.run(A1, B1, C1, D1)
95
-
96
- Elementwise activation functions are easily fused to the GEMM via the interface:
97
-
98
- .. highlight:: python
99
- .. code-block:: python
100
-
101
- plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
102
- plan.activation = cutlass_cppgen.epilogue.relu
103
-
104
- Operations can also be run asynchronously:
105
-
106
- .. highlight:: python
107
- .. code-block:: python
108
-
109
- plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
110
- args = plan.run()
111
-
112
- # Do other work...
113
-
114
- args.sync()
115
- """
116
- from __future__ import annotations
117
- from typing import Optional
118
- from math import prod
119
-
120
- from cutlass_cppgen.utils.lazy_import import lazy_import
121
- cuda = lazy_import("cuda.cuda")
122
- from cutlass_library import (
123
- DataType,
124
- DataTypeSize,
125
- GemmUniversalMode,
126
- KernelScheduleSuffixes,
127
- )
128
-
129
- import cutlass_cppgen
130
- from cutlass_cppgen import epilogue, swizzle
131
- from cutlass_cppgen.backend import compiler
132
- from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
133
- from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
134
- from cutlass_cppgen.backend.library import TensorDescription, TileDescription
135
- from cutlass_cppgen.op.op import OperationBase
136
- from cutlass_cppgen.shape import GemmCoord
137
- from cutlass_cppgen.utils import check, datatypes
138
-
139
-
140
- class Gemm(OperationBase):
141
- """
142
- Constructs a ``Gemm`` object.
143
-
144
- The data types and layouts of operands A, B, and C, along with the data type of output D
145
- and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
146
- these are not to be changed after a ``Gemm`` has been constructed.
147
-
148
- The constructor has optional parameters for flexibly setting these parameters. The following
149
- constructors are equivalent:
150
-
151
- .. highlight:: python
152
- .. code-block:: python
153
-
154
- # Use F32 for A, B, C, D, and accumulation. All operands are row major.
155
-
156
- # Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
157
- # for operands to the same values.
158
- Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
159
-
160
- # Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
161
- Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
162
- element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
163
-
164
- # Set the data types and elements from existing tensors. Note that one can use different tensors when
165
- # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
166
- # have the same data type and layout as those passed in here).
167
- # A, B, C, and D are row-major torch.Tensor objects of type torch.float32
168
- Gemm(A=A, B=B, C=C, D=D)
169
-
170
- # Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
171
- # the same as that for D, at present)
172
- Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
173
- layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
174
-
175
- # Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
176
- # and layouts will inherit those passed in via the generic ``element`` and ``layout``
177
- Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
178
- element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
179
-
180
- The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
181
- 1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
182
- 2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
183
- 3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
184
-
185
- :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
186
- :type cc: int
187
- :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
188
- :type kernel_cc: int
189
- :param A: tensor representing data type and layout of operand A
190
- :param B: tensor representing data type and layout of operand B
191
- :param C: tensor representing data type and layout of operand C
192
- :param D: tensor representing data type and layout of operand D
193
- :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
194
- :param beta: scalar parameter beta from GEMM operation that scales operand C
195
- :param element_accumulator: data type to be used in accumulation of the product of operands A and B
196
- :type element_accumulator: cutlass_cppgen.DataType
197
- :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
198
- :type element: cutlass_cppgen.DataType
199
- :param layout: generic layout type to be used for operands A, B, C, and D
200
- :type layout: cutlass_cppgen.LayoutType
201
- :param element_A: data type to be used for operand A
202
- :type element_A: cutlass_cppgen.DataType
203
- :param element_B: data type to be used for operand B
204
- :type element_B: cutlass_cppgen.DataType
205
- :param element_C: data type to be used for operand C
206
- :type element_C: cutlass_cppgen.DataType
207
- :param element_D: data type to be used for operand D
208
- :type element_D: cutlass_cppgen.DataType
209
- :param layout_A: layout of operand A
210
- :type layout_A: cutlass_cppgen.LayoutType
211
- :param layout_B: layout of operand B
212
- :type layout_B: cutlass_cppgen.LayoutType
213
- :param layout_C: layout of operand C
214
- :type layout_C: cutlass_cppgen.LayoutType
215
- :param layout_D: layout of operand D
216
- :type layout_D: cutlass_cppgen.LayoutType
217
- """
218
-
219
- def __init__(
220
- self, A=None, B=None, C=None, D=None,
221
- alpha=1.0, beta=0.0, element_accumulator=None,
222
- element=None, layout=None,
223
- element_A=None, element_B=None, element_C=None, element_D=None,
224
- layout_A=None, layout_B=None, layout_C=None,
225
- cc: int = None, kernel_cc: int = None
226
- ):
227
- super().__init__(cc=cc, kernel_cc=kernel_cc)
228
- self.name = "gemm"
229
- self.compiled = False
230
-
231
- elements = []
232
- layouts = []
233
-
234
- # Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
235
- # ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
236
- for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
237
- [layout_A, layout_B, layout_C, layout_C],
238
- [A, B, C, D],
239
- ["A", "B", "C", "D"]):
240
- if elt is not None and tens is not None:
241
- raise Exception(f'Must not specify both element_{name} and tensor {name}')
242
- if lay is not None and tens is not None:
243
- raise Exception(f'Must not specify both layout_{name} and tensor {name}')
244
- if elt is None and tens is None and element is None:
245
- raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
246
- if lay is None and tens is None and layout is None:
247
- raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
248
-
249
- elt_to_set = None
250
- lay_to_set = None
251
- if tens is not None:
252
- elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
253
- else:
254
- elt_to_set = elt if elt is not None else element
255
- lay_to_set = lay if lay is not None else layout
256
-
257
- elements.append(datatypes.library_type(elt_to_set))
258
- layouts.append(lay_to_set)
259
-
260
- self._element_a, self._element_b, self._element_c, self._element_d = elements
261
- self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
262
-
263
- if element_accumulator is None:
264
- self._element_accumulator = self._element_c
265
- else:
266
- self._element_accumulator = datatypes.library_type(element_accumulator)
267
-
268
- self.A = A
269
- self.B = B
270
- self.C = C
271
- self.D = D
272
-
273
- self.alpha = alpha
274
- self.beta = beta
275
-
276
- self.epilogue_functor = None
277
- self.op_class = None
278
- self._tile_description = None
279
-
280
- self._reset_operations()
281
-
282
- self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
283
-
284
- def _reset_operations(self, reset_epilogue: bool = True):
285
- # Set the default op class
286
- datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
287
- layout_comb = (self._layout_a, self._layout_b)
288
-
289
- self.possible_op_classes = self.options.supporting_opclasses(
290
- self._element_a, self._element_b, self._element_accumulator,
291
- self._layout_a, self._layout_b, self._math_operation)
292
-
293
- if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
294
- self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
295
- elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
296
- self.opclass = cutlass_cppgen.OpcodeClass.Simt
297
- else:
298
- if self._math_operation is not None:
299
- math_op_str = f' and math operation {self._math_operation}'
300
- else:
301
- math_op_str = ''
302
-
303
- raise Exception(f'No kernel configuration found for supported data type and layout '
304
- f'combination {datatype_comb}x{layout_comb}{math_op_str}')
305
-
306
- if reset_epilogue:
307
- self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
308
-
309
- @property
310
- def swizzling_functor(self):
311
- """
312
- Returns the type of the swizzling functor currently being used by the GEMM
313
-
314
- :return: swizzing functor type
315
- """
316
- return self._swizzling_functor
317
-
318
- @swizzling_functor.setter
319
- def swizzling_functor(self, swizzling_functor):
320
- """
321
- Sets the swizzling functor to the type specified by `swizzling_functor`
322
- """
323
- if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
324
- if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
325
- raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
326
-
327
- if self.current_cc in [90, 100, 101, 103]:
328
- raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')
329
- self._swizzling_functor = swizzling_functor
330
-
331
- #
332
- # Tile description Related
333
- #
334
-
335
- @property
336
- def tile_description(self) -> TileDescription:
337
- """
338
- Returns the tile description
339
- """
340
- return self._tile_description
341
-
342
- @tile_description.setter
343
- def tile_description(
344
- self, td=None):
345
- """
346
- Set the tile description
347
-
348
- :param td: tile description
349
- :type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
350
- {
351
- "threadblock_shape": [int, int, int],
352
- "warp_count": [int, int, int],
353
- "stages": int,
354
- "instruction_shape": [int, int, int] (optional),
355
- "cluster_shape": [int, int, int] (optional)
356
- }
357
- """
358
- if td is None:
359
- return
360
- if isinstance(td, dict):
361
- if self._tile_description is None:
362
- op = self.possible_operations.default_operation(self._math_operation)
363
- self._tile_description = datatypes.td_from_profiler_op(op)
364
- td = self._tile_description.clone_and_update(td)
365
-
366
- valid, msg = self._valid_tile_description(td)
367
- if valid:
368
- self._tile_description = td
369
- else:
370
- raise Exception(msg)
371
-
372
- def _valid_tile_description(self, td: TileDescription) -> tuple:
373
- """
374
- Checks whether the provided tile description is valid for the given compute capability. At present,
375
- this checks the following:
376
-
377
- - Does the tile description use a number of stages supported by the compute capability in question?
378
- - Does the tile size requested fit within shared memory?
379
- - Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
380
- more non-unit cluster dimensions for pre-SM90 architectures)?
381
- - Is the kernel schedule being used supported on the architecture in question?
382
-
383
- :param td: tile description to validate
384
- :type td: cutlass_cppgen.backend.TileDescription
385
- :return: tuple in which the first element is a bool indicating that the tile description is valid
386
- and the second element is a string providing an optional error message.
387
- :rtype: tuple
388
- """
389
- valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
390
- if not valid:
391
- return (valid, msg)
392
-
393
- valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
394
- if not valid:
395
- return (valid, msg)
396
-
397
- valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
398
-
399
- if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0:
400
- valid = False
401
- msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"
402
-
403
- return valid, msg
404
-
405
- def tile_descriptions(self) -> list:
406
- """
407
- Returns a list of valid tile descriptions for the operations
408
-
409
- :returns: list of valid tile descriptions for the operations
410
- :rtype: list
411
- """
412
- tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
413
- if self._math_operation is not None:
414
- tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation]
415
- return tds
416
-
417
- def construct(
418
- self, tile_description: TileDescription = None,
419
- alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
420
- """
421
- Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
422
- kernel specification of the ``Gemm`` object.
423
-
424
- :param tile_description: tile description specifying shapes and operand types to use in the kernel
425
- :type tile_description: cutlass_cppgen.backend.TileDescription
426
- :param alignment_A: alignment of operand A
427
- :type alignment_A: int
428
- :param alignment_B: alignment of operand B
429
- :type alignment_B: int
430
- :param alignment_C: alignment of operand C
431
- :type alignment_C: int
432
-
433
- :return: operation that was constructed
434
- :rtype: cutlass_cppgen.backend.GemmOperationUniversal
435
- """
436
- alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
437
- alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
438
- alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
439
- alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
440
-
441
- tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
442
- tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
443
-
444
- if alignment_C is None:
445
- alignment_C = max(self.possible_operations.alignments("C"))
446
- if self._element_c != DataType.void:
447
- alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
448
-
449
- if tile_description is None:
450
- if self._tile_description is None:
451
- op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
452
- tile_description = datatypes.td_from_profiler_op(op)
453
-
454
- # The selected op may have lower alignment than that determined above, so we must
455
- # reset alignment here.
456
- alignment_C = op.C.alignment
457
- else:
458
- tile_description = self._tile_description
459
- else:
460
- valid, err_str = self._valid_tile_description(tile_description)
461
- if not valid:
462
- raise Exception(f"Invalid tile description. {err_str}")
463
- self._tile_description = tile_description
464
-
465
- tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
466
- self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
467
-
468
- operation = GemmOperationUniversal(
469
- arch=self.current_cc,
470
- tile_description=tile_description,
471
- A=tensor_A, B=tensor_B, C=tensor_C,
472
- epilogue_functor=self.epilogue_functor,
473
- swizzling_functor=self._swizzling_functor,
474
- )
475
-
476
- return operation
477
-
478
- def compile(self, tile_description: TileDescription = None,
479
- alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
480
- print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
481
- """
482
- Emits and compiles the kernel currently specified. If ``tile_description`` and any
483
- of the ``alignment`` parameters are set, the kernel will be chosen using this
484
- tile description and alignments. Otherwise, a default tile description and alignment
485
- will be used.
486
-
487
- :param tile_description: tile description specifying shapes and operand types to use in the kernel
488
- :type tile_description: cutlass_cppgen.backend.TileDescription
489
- :param alignment_A: alignment of operand A
490
- :type alignment_A: int
491
- :param alignment_B: alignment of operand B
492
- :type alignment_B: int
493
- :param alignment_C: alignment of operand C
494
- :type alignment_C: int
495
- :param print_module: whether to print the emitted C++ code
496
- :type print_module: bool
497
-
498
- :return: operation that was compiled
499
- :rtype: cutlass_cppgen.backend.GemmOperationUniversal
500
- """
501
- self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
502
-
503
- if print_module:
504
- print(self.operation.rt_module.emit())
505
-
506
- compiler.add_module([self.operation,])
507
- return self.operation
508
-
509
- def _verify_rank(self, tensor):
510
- """
511
- Verifies that ``tensor`` has rank greater than 1
512
-
513
- :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
514
- :type tensor: numpy/cupy/torch array/tensor object
515
- """
516
- if len(tensor.shape) < 2:
517
- raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
518
-
519
- def _get_batch_count(self, A, B, C, D) -> int:
520
- """
521
- Returns the batch count specified by the tensors A, B, C, and D and verifies that these
522
- tensors match in batch size. Presence of a batch dimension is detected by one of the
523
- tensors being rank 3. If a batch dimension is present, it must be present in one of
524
- operands A, B, or C (but need not be in all), and must be present in D.
525
-
526
- :param A: tensor A
527
- :type A: numpy/cupy/torch array/tensor object
528
- :param B: tensor B
529
- :type B: numpy/cupy/torch array/tensor object
530
- :param C: tensor C
531
- :type C: numpy/cupy/torch array/tensor object
532
- :param D: tensor D
533
- :type D: numpy/cupy/torch array/tensor object
534
-
535
- :return: tuple of batch count dimensions
536
- :rtype: tuple
537
- """
538
- A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
539
- B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1
540
-
541
- if 1 not in [A_batch, B_batch]:
542
- if A_batch != B_batch:
543
- raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
544
- return max(A_batch, B_batch)
545
-
546
- def _get_batch_stride(self, tensor) -> int:
547
- """
548
- Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
549
-
550
- :param tensor: tensor object to process
551
- :type tensor: numpy/cupy/torch array/tensor object
552
-
553
- :return: stride between each matrix in the batch
554
- :rtype: int
555
- """
556
- if tensor is not None and len(tensor.shape) > 2:
557
- return tensor.shape[-2] * tensor.shape[-1]
558
- else:
559
- return 0
560
-
561
- def _get_problem_args(self, A, B, C, D) -> tuple:
562
- """
563
- Returns the problem size and GEMM universal mode to use for the
564
- given operands.
565
-
566
- :param A: tensor A
567
- :type A: numpy/cupy/torch array/tensor object
568
- :param B: tensor B
569
- :type B: numpy/cupy/torch array/tensor object
570
- :param C: tensor C
571
- :type C: numpy/cupy/torch array/tensor object
572
- :param D: tensor D
573
- :type D: numpy/cupy/torch array/tensor object
574
-
575
- :return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
576
- :rtype: tuple
577
- """
578
- M, K = A.shape[-2:]
579
- N = B.shape[-1]
580
- mode = GemmUniversalMode.Gemm
581
-
582
- batch_count = self._get_batch_count(A, B, C, D)
583
- returned_batch_count = batch_count
584
-
585
- # If we are running a batched GEMM in which there is a nonzero batch stride
586
- # only for A, then we can fold the batched dimension of A into the M dimension
587
- # (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
588
- # and C are row major. A similar operation can be performed if only B has a nonzero
589
- # batch dimension
590
- if batch_count > 1:
591
- A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
592
- B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
593
- C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor
594
-
595
- # Consider a Tensor to be batched if its rank is > 2 and
596
- # the product of the modes beyond rank 2 equals our pre-determined batch size.
597
- batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)
598
-
599
- if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
600
- M *= batch_count
601
- returned_batch_count = 1
602
- elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
603
- N *= batch_count
604
- returned_batch_count = 1
605
- else:
606
- mode = GemmUniversalMode.Batched
607
-
608
- return GemmCoord(M, N, K), mode, returned_batch_count
609
-
610
- def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
611
- """
612
- Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
613
- is raised if it does not.
614
-
615
- :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
616
- :type tensor: numpy/cupy/torch array/tensor object
617
- :param ref_dtype: data type for the tensor that this object was initialized to
618
- :param ref_layout: layout for the tensor that this object was initialized to
619
- :param name: identifier of the tensor to verify. Used in raising exceptions
620
- :type name: str
621
- """
622
- dtype, layout = datatypes.get_datatype_and_layout(tensor)
623
- if dtype != ref_type or layout != ref_layout:
624
- try:
625
- # Attempt to transpose the tensor to fit the desired layout
626
- tensor = tensor.transpose(-1, -2)
627
- except:
628
- raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
629
- f'does not match the expected type and '
630
- f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
631
-
632
- def run(self, A=None, B=None, C=None, D=None,
633
- alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
634
- stream: Optional[cuda.CUstream] = None) -> GemmArguments:
635
- """
636
- Runs the kernel currently specified. If it has not already been, the kernel is emitted and
637
- compiled. Tensors holding operands and outputs of the kernel are sourced either from the
638
- ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
639
- parameters provided in this call, or from those
640
- passed in on the construction of this object -- one of the two must be specified.
641
-
642
- By default, this call returns only once the kernel has completed. To launch the kernel
643
- and immediately return, set ``sync=False``. In this case, it is the responsibility of the
644
- caller to syncrhonize the results of the kernel before attempting to access outputs
645
- by calling ``sync()`` on the arguments returned from this call.
646
-
647
- :param A: tensor representing data type and layout of operand A
648
- :param B: tensor representing data type and layout of operand B
649
- :param C: tensor representing data type and layout of operand C
650
- :param D: tensor representing data type and layout of operand D
651
- :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
652
- :param beta: scalar parameter beta from GEMM operation that scales operand C
653
- :param sync: whether the call should wait for the kernel to complete before returning
654
- :type sync: bool
655
- :param print_module: whether to print the emitted C++ code
656
- :type print_module: bool
657
- :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
658
- :type stream: :class:`cuda.cuda.CUstream`
659
-
660
- :return: arguments passed in to the kernel
661
- :rtype: cutlass_cppgen.backend.GemmArguments
662
- """
663
- if not stream:
664
- stream = cuda.CUstream(0)
665
- super().run_setup()
666
- A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
667
- B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
668
- C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
669
- D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
670
- alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
671
- beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
672
-
673
- is_void_c = self._element_c == DataType.void
674
-
675
- self._verify_rank(A)
676
- self._verify_rank(B)
677
- if not is_void_c:
678
- self._verify_rank(C)
679
- self._verify_rank(D)
680
-
681
- alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
682
- alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")
683
-
684
- # Set C alignment based on D.shape so as to correctly get an alignment with void-C
685
- # kernels, for which `C` is None.
686
- alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
687
- self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
688
- alignment_C=alignment_c, print_module=print_module)
689
-
690
- problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
691
-
692
- if mode == GemmUniversalMode.Gemm or batch_count == 1:
693
- kwargs = {'split_k_slices': 1}
694
- else:
695
- kwargs = {
696
- 'batch': batch_count,
697
- 'batch_strides': {
698
- 'A': self._get_batch_stride(A),
699
- 'B': self._get_batch_stride(B),
700
- 'C': self._get_batch_stride(C),
701
- 'D': self._get_batch_stride(D)
702
- }
703
- }
704
-
705
- kwargs['stream'] = stream
706
-
707
- if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
708
- output_op = self.operation.epilogue_type(visitor_args)
709
- else:
710
- output_op = self.operation.epilogue_type(alpha, beta)
711
-
712
- arguments = GemmArguments(
713
- operation=self.operation, problem_size=problem_size,
714
- A=A, B=B, C=C, D=D,
715
- output_op=output_op,
716
- gemm_mode=mode,
717
- **kwargs
718
- )
719
-
720
- self.operation.run(arguments)
721
-
722
- if sync:
723
- arguments.sync()
724
-
725
- return arguments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py DELETED
@@ -1,269 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Ease-of-use interface for constructing, compiling, and running GEMMs.
35
-
36
- The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
37
- grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
38
- Under the hood, the interface will select sensible default parameters for the many template
39
- parameters for CUTLASS grouped GEMMs.
40
-
41
- Note: optimal performance is not to be expected from this interface. To achieve optimal
42
- performance, one should specify and tune each configuration parameter.
43
-
44
- The simplest example of using this interface is the following:
45
-
46
- .. highlight:: python
47
- .. code-block:: python
48
-
49
- # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
50
- plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
51
- plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
52
- """
53
- from __future__ import annotations
54
- from typing import Optional
55
- from cutlass_library import DataTypeSize
56
-
57
- from cutlass_cppgen.utils.lazy_import import lazy_import
58
- cuda = lazy_import("cuda.cuda")
59
- from cutlass_cppgen.backend.gemm_operation import (
60
- GemmGroupedArguments,
61
- GemmOperationGrouped,
62
- )
63
- from cutlass_cppgen.backend.library import (
64
- SchedulerMode,
65
- TensorDescription,
66
- TileDescription,
67
- )
68
- from cutlass_cppgen.op.gemm import Gemm
69
- from cutlass_cppgen.shape import GemmCoord
70
- from cutlass_cppgen.utils import check, datatypes
71
-
72
-
73
- class GroupedGemm(Gemm):
74
- """
75
- Constructs a ``GroupedGemm`` object.
76
-
77
- The data types and layouts of operands A, B, and C, along with the data type of output D
78
- and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
79
- these are not to be changed after a ``GroupedGemm`` has been constructed.
80
-
81
- The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
82
- for ``Gemm`` for examples of these.
83
-
84
- :param cc: compute capability of device to generate kernels for
85
- :type cc: int
86
- :param A: tensor representing data type and layout of operands A
87
- :param B: tensor representing data type and layout of operands B
88
- :param C: tensor representing data type and layout of operands C
89
- :param D: tensor representing data type and layout of operands D
90
- :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
91
- :param beta: scalar parameter beta from GEMM operation that scales operand C
92
- :param element_accumulator: data type to be used in accumulation of the product of operands A and B
93
- :type element_accumulator: cutlass_cppgen.DataType
94
- :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
95
- :type element: cutlass_cppgen.DataType
96
- :param layout: generic layout type to be used for operands A, B, C, and D
97
- :type layout: cutlass_cppgen.LayoutType
98
- :param element_A: data type to be used for operand A
99
- :type element_A: cutlass_cppgen.DataType
100
- :param element_B: data type to be used for operand B
101
- :type element_B: cutlass_cppgen.DataType
102
- :param element_C: data type to be used for operand C
103
- :type element_C: cutlass_cppgen.DataType
104
- :param element_D: data type to be used for operand D
105
- :type element_D: cutlass_cppgen.DataType
106
- :type layout_A: layout of operand A
107
- :param layout_A: cutlass_cppgen.LayoutType
108
- :type layout_B: layout of operand B
109
- :param layout_B: cutlass_cppgen.LayoutType
110
- :type layout_C: layout of operand C
111
- :param layout_C: cutlass_cppgen.LayoutType
112
- :type layout_D: layout of operand D
113
- :param layout_D: cutlass_cppgen.LayoutType
114
- """
115
-
116
- def __init__(
117
- self, A=None, B=None, C=None, D=None,
118
- alpha=1.0, beta=0.0, element_accumulator=None,
119
- element=None, layout=None,
120
- element_A=None, element_B=None, element_C=None, element_D=None,
121
- layout_A=None, layout_B=None, layout_C=None,
122
- cc: int = None,
123
- ):
124
- super().__init__(
125
- A=A, B=B, C=C, D=D,
126
- alpha=alpha, beta=beta,
127
- element_accumulator=element_accumulator,
128
- element=element, layout=layout,
129
- element_A=element_A, element_B=element_B,
130
- element_C=element_C, element_D=element_D,
131
- layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
132
- cc=cc
133
- )
134
-
135
- # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
136
- if self.current_cc in [90, 100, 101, 103]:
137
- self._reset_options(80)
138
- self._reset_operations(reset_epilogue=False)
139
-
140
- self.name = "grouped_gemm"
141
-
142
- @Gemm.swizzling_functor.setter
143
- def swizzling_functor(self, swizzling_functor):
144
- """
145
- Sets the swizzling functor to the type specified by `swizzling_functor`
146
- """
147
- raise Exception('Grouped GEMM does not currently support different swizzling functors')
148
-
149
- def construct(self, tile_description: TileDescription = None,
150
- alignment_A: int = None,
151
- alignment_B: int = None,
152
- alignment_C: int = None) -> GemmOperationGrouped:
153
- """
154
- Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
155
- kernel specification of the ``Gemm`` object.
156
-
157
- :param tile_description: tile description specifying shapes and operand types to use in the kernel
158
- :type tile_description: cutlass_cppgen.backend.TileDescription
159
- :param alignment_A: alignment of operand A
160
- :type alignment_A: int
161
- :param alignment_B: alignment of operand B
162
- :type alignment_B: int
163
- :param alignment_C: alignment of operand C
164
- :type alignment_C: int
165
-
166
- :return: operation that was constructed
167
- :rtype: cutlass_cppgen.backend.GemmOperationGrouped
168
- """
169
- alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
170
- alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
171
- alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))
172
-
173
- self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
174
-
175
- tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
176
- tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
177
- tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
178
-
179
- if tile_description is None:
180
- op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
181
- tile_description = datatypes.td_from_profiler_op(op)
182
- else:
183
- valid, err_str = self._valid_tile_description(tile_description)
184
- if not valid:
185
- raise Exception(f"Invalid tile description. {err_str}")
186
- self.tile_description = tile_description
187
-
188
- operation = GemmOperationGrouped(
189
- arch=self.current_cc,
190
- tile_description=tile_description,
191
- A=tensor_A, B=tensor_B, C=tensor_C,
192
- epilogue_functor=self.epilogue_functor,
193
- swizzling_functor=self._swizzling_functor,
194
- precompute_mode=SchedulerMode.Device)
195
-
196
- return operation
197
-
198
- def run(self, A, B, C, D,
199
- alpha=None, beta=None, sync: bool = True,
200
- print_module: bool = False,
201
- stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
202
- """
203
- Runs the kernel currently specified.
204
-
205
- By default, this call returns only once the kernel has completed. To launch the kernel
206
- and immediately return, set ``sync=False``. In this case, it is the responsibility of the
207
- caller to syncrhonize the results of the kernel before attempting to access outputs
208
- by calling ``sync()`` on the arguments returned from this call.
209
-
210
- :param A: list of tensors representing data type and layout of operand A
211
- :type A: list
212
- :param B: list of tensors representing data type and layout of operand B
213
- :type B: list
214
- :param C: list of tensors representing data type and layout of operand C
215
- :type C: list
216
- :param D: list of tensors representing data type and layout of operand D
217
- :type D: list
218
- :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
219
- :param beta: scalar parameter beta from GEMM operation that scales operand C
220
- :param sync: whether the call should wait for the kernel to complete before returning
221
- :type sync: bool
222
- :param print_module: whether to print the emitted C++ code
223
- :type print_module: bool
224
- :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
225
- :type stream: :class:`cuda.cuda.CUstream`
226
-
227
- :return: arguments passed in to the kernel
228
- :rtype: cutlass_cppgen.backend.GemmGroupedArguments
229
- """
230
- if not stream:
231
- stream = cuda.CUstream(0)
232
-
233
- super().run_setup()
234
-
235
- if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
236
- raise Exception("Lengths of A, B, C, and D lists must be equal")
237
-
238
- problem_sizes = []
239
- As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
240
- for i in range(len(A)):
241
- As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
242
- Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
243
- Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
244
- Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
245
- problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
246
-
247
- alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
248
- beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
249
-
250
- alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
251
- alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
252
- alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
253
- self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
254
- alignment_C=alignment_c, print_module=print_module)
255
-
256
- arguments = GemmGroupedArguments(
257
- operation=self.operation,
258
- problem_sizes=problem_sizes,
259
- A=As, B=Bs, C=Cs, D=Ds,
260
- output_op=self.operation.epilogue_type(alpha, beta),
261
- stream=stream
262
- )
263
-
264
- self.operation.run(arguments)
265
-
266
- if sync:
267
- arguments.sync()
268
-
269
- return arguments
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py DELETED
@@ -1,431 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
35
- """
36
-
37
- from bisect import bisect_left
38
-
39
- from cutlass_library import (
40
- DataType,
41
- DataTypeSize,
42
- MathOperation,
43
- OperationKind,
44
- SharedMemPerCC
45
- )
46
-
47
- import cutlass_cppgen
48
- from cutlass_cppgen import get_option_registry
49
- from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
50
- from cutlass_cppgen.backend.evt.passes.util import cc_map
51
- from cutlass_cppgen.backend.utils.device import device_cc
52
- from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
53
- from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
54
- from cutlass_cppgen.swizzle import get_swizzling_functors
55
- from cutlass_cppgen.utils import datatypes, check
56
-
57
-
58
- class OperationBase:
59
- """
60
- Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
61
- """
62
-
63
- def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
64
- """
65
- :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
66
- :type cc: int
67
- :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
68
- :type kernel_cc: int
69
- :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
70
- :type operation_kind: cutlass_library.OperationKind
71
- """
72
- self.operation_kind = operation_kind
73
- self.cc = cc if cc is not None else device_cc()
74
- self.specified_kernel_cc = kernel_cc is not None
75
- self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
76
- self.tile_description = None
77
- self._math_operation = None
78
-
79
- self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)
80
-
81
- if self.options is None:
82
- raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
83
-
84
- # Default activation function: identity
85
- self._activation = identity
86
-
87
- def _find_closest_cc(self, cc: int) -> int:
88
- """
89
- Returns the closest CC in _generator_ccs less than or equal to `cc`
90
-
91
- :param cc: compute capability to query
92
- :type cc: int
93
-
94
- :returns: closest CC in _generator_ccs less than or equal to `cc`
95
- :rtype: int
96
- """
97
- if cc in _generator_ccs:
98
- return cc
99
-
100
- # Find closest CC lower than this CC
101
- idx = bisect_left(_generator_ccs, cc)
102
- if idx == 0:
103
- raise Exception(f'No valid CC to fall back to for {cc}')
104
- return _generator_ccs[idx-1]
105
-
106
- def activations(self) -> list:
107
- """
108
- Returns possible activation functions that can be used
109
-
110
- :return: list of activation functions that can be used
111
- :rtype: list
112
- """
113
- return get_activations()
114
-
115
- def swizzling_functors(self) -> list:
116
- """
117
- Returns possible swizzling functions that can be used
118
-
119
- :return: list of swizzling functions that can be used
120
- :rtype: list
121
- """
122
- return get_swizzling_functors()
123
-
124
- def _reset_options(self, cc: int):
125
- """
126
- Resets the kernel options based on cc
127
-
128
- :param cc: compute capability to reset to
129
- :type cc: int
130
- """
131
- if cc != self.current_cc:
132
- if cc not in _generator_ccs:
133
- raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
134
- self.current_cc = cc
135
- self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)
136
-
137
- def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
138
- """
139
- Verifies the following properties:
140
- 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
141
- 2) If ``scalar`` is not ``None``, its datatype must match matches the current version
142
- set by the plan (i.e., those in ``ref_dtype``)
143
-
144
- If either of these properties does not hold, an exception is raised. If these properties hold and
145
- ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
146
-
147
- :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
148
- :type scalar: numpy/cupy/torch scalar
149
- :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
150
- :type ref_scalar: numpy/cupy/torch scalar
151
- :param ref_dtype: data type for the scalar that this object was initialized to
152
- :param name: identifier of the scalar to verify. Used in raising exceptions
153
- :type name: str
154
-
155
- :return: valid scalar to use
156
- :rtype: numpy/cupy/torch scalar
157
- """
158
- if scalar is None:
159
- if ref_scalar is None:
160
- raise Exception(f"Scalar {name} must be set.")
161
- return ref_scalar
162
- if hasattr(scalar, "dtype"):
163
- dtype = datatypes.library_type(scalar.dtype)
164
- if dtype != ref_dtype:
165
- raise Exception(
166
- f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
167
- )
168
- return scalar
169
-
170
- def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
171
- """
172
- Verifies the following properties:
173
- If ref_dtype is not void:
174
- 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
175
- 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
176
- set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
177
- If ref_dtype is void:
178
- Neither ``tensor`` nor ``ref_tensor`` are set
179
-
180
- If either of these properties does not hold, an exception is raised. If these properties hold and
181
- ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
182
-
183
- :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
184
- :type tensor: numpy/cupy/torch array/tensor object
185
- :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
186
- :type ref_tensor: numpy/cupy/torch array/tensor object
187
- :param ref_dtype: data type for the tensor that this object was initialized to
188
- :param ref_layout: layout for the tensor that this object was initialized to
189
- :param name: identifier of the tensor to verify. Used in raising exceptions
190
- :type name: str
191
-
192
- :return: valid tensor object to use
193
- :rtype: numpy/cupy/torch array/tensor object
194
- """
195
- if ref_dtype == DataType.void:
196
- if tensor is not None or ref_tensor is not None:
197
- raise Exception("Operands with element DataType.void must not be provided a tensor")
198
- return None
199
-
200
- if tensor is None:
201
- if ref_tensor is None:
202
- raise Exception(f"Tensor {name} must be set.")
203
- return ref_tensor
204
-
205
- self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
206
- return tensor
207
-
208
- @property
209
- def opclass(self) -> cutlass_cppgen.OpcodeClass:
210
- """
211
- Returns the opcode class currently in use
212
-
213
- :return: opcode class currently in use
214
- :rtype: cutlass_cppgen.OpcodeClass
215
- """
216
- return self.op_class
217
-
218
- @opclass.setter
219
- def opclass(self, oc: cutlass_cppgen.OpcodeClass):
220
- if isinstance(oc, str):
221
- oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
222
- if oc in self.possible_op_classes:
223
- self.op_class = oc
224
- else:
225
- raise Exception(
226
- f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
227
- f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
228
- f'layout combination ({self._layout_a}, {self._layout_b}).')
229
-
230
- # Changing the op class also changes the possible operations available. Reset these.
231
- self.possible_operations = self.options.operations(
232
- self.op_class, self._element_a, self._element_b,
233
- self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
234
-
235
- # Changing the op class changes the elements per access in the epilogue. Reset this.
236
- if self.epilogue_functor is not None:
237
- self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
238
-
239
- @property
240
- def math_operation(self) -> cutlass_cppgen.MathOperation:
241
- """
242
- Returns the math operation currently in use
243
-
244
- :return: math operation currently in use
245
- :rtype: cutlass_cppgen.MathOperation
246
- """
247
- return self._math_operation
248
-
249
- @math_operation.setter
250
- def math_operation(self, mo: cutlass_cppgen.MathOperation):
251
- if isinstance(mo, str):
252
- mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
253
-
254
- if not self.specified_kernel_cc:
255
- if self.current_cc in [90, 100, 101, 103]:
256
- # CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
257
- # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
258
- cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
259
- self._reset_options(80)
260
- self._reset_operations(reset_epilogue=False)
261
- elif self.current_cc in [90, 100, 101, 103]:
262
- raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
263
- "To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
264
- "parameter when constructing the plan.")
265
-
266
- self._math_operation = mo
267
- self._reset_operations()
268
-
269
- def _elements_per_access(self):
270
- if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
271
- return 1
272
- elif self._element_c != DataType.void:
273
- return 128 // DataTypeSize[self._element_c]
274
- else:
275
- return 128 // max(self.possible_operations.alignments("C"))
276
-
277
- def _create_epilogue_functor_activation(self, activation):
278
- """
279
- Returns the epilogue functor with given activation function
280
- """
281
- if self.epilogue_functor is None:
282
- elements_per_access = self._elements_per_access()
283
- else:
284
- elements_per_access = self.epilogue_functor.epilogue_vector_length
285
-
286
- if not self.specified_kernel_cc:
287
- if self.current_cc in [90, 100, 101, 103] and activation != identity:
288
- # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
289
- # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
290
- cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
291
- if self._element_c != self._element_d:
292
- raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
293
- self._reset_options(80)
294
- self._reset_operations(reset_epilogue=False)
295
- elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None):
296
- # SM80 fallback kernels are currently used. Since an identity activation is requested,
297
- # we can switch back to using SM90 kernels.
298
- self._reset_options(self.cc)
299
- self._reset_operations(reset_epilogue=False)
300
- else:
301
- if self.current_cc in [90, 100, 101, 103] and activation != identity:
302
- raise Exception("Epilogues with elementwise fusion are not currently supported "
303
- "in the Python interface for 3.x kernels. To use 2.x kernels "
304
- "with fused elementwise epilogues, do not set the `kernel_cc` "
305
- "parameter when constructing the plan.")
306
-
307
- return get_activation_epilogue(
308
- activation,
309
- self._element_d,
310
- elements_per_access,
311
- self._element_accumulator,
312
- self._element_accumulator,
313
- )
314
-
315
- def _reset_epilogue_functor_activation(self, activation):
316
- """
317
- Set the epilogue functor based on the provided activation function
318
- """
319
- self.epilogue_functor = self._create_epilogue_functor_activation(activation)
320
-
321
- def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
322
- """
323
- Reset the alignment of the current epilogue functor based on alignment C
324
- """
325
- if isinstance(epilogue_functor, EpilogueFunctorVisitor):
326
- return epilogue_functor
327
-
328
- if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
329
- # Identity epilogue does not have 'activation_functor'
330
- activation = identity
331
- else:
332
- activation = epilogue_functor.activation_functor
333
-
334
- epilogue_functor = get_activation_epilogue(
335
- activation,
336
- self._element_d,
337
- alignment,
338
- self._element_accumulator,
339
- self._element_accumulator,
340
- )
341
- return epilogue_functor
342
-
343
- @property
344
- def activation(self):
345
- """
346
- Returns the type of the current activation function used
347
- """
348
- if hasattr(self.epilogue_functor, "activation_functor"):
349
- return self.epilogue_functor.activation_functor
350
- else:
351
- return identity
352
-
353
- @activation.setter
354
- def activation(self, act):
355
- """
356
- Sets the type of the activation function to use
357
- Activation can come with a set of arguments
358
-
359
- :param act: type of activation function to use
360
- :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
361
-
362
- """
363
- if isinstance(act, tuple):
364
- if isinstance(act[0], str):
365
- act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
366
- else:
367
- act_fn = act[0]
368
- self._reset_epilogue_functor_activation(act_fn)
369
- self._activation_args = act[1]
370
- self._activation = act[0]
371
- else:
372
- if isinstance(act, str):
373
- act = getattr(cutlass_cppgen.backend.epilogue, act)
374
- self._reset_epilogue_functor_activation(act)
375
- self._activation = act
376
-
377
- @property
378
- def epilogue_visitor(self):
379
- """
380
- Return the epilogue functor
381
- """
382
- return self.epilogue_functor
383
-
384
- @epilogue_visitor.setter
385
- def epilogue_visitor(self, visitor):
386
- """
387
- Create the epilogue visitor
388
- """
389
- self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)
390
-
391
- # The epilogue_functor may consume too much shared memory
392
- # Reset the possible operations
393
- if self.cc not in [90, 100, 101, 103]:
394
- # The shared memory is only a concern for sm90+ epilogue
395
- # In sm80, the epilogue and mainloop share the shared memory
396
- return
397
-
398
- datatype_comb = self.possible_operations.datatype_comb
399
- layout_comb = self.possible_operations.layout_comb
400
- new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
401
- for operation in self.possible_operations.all_operations:
402
- td = datatypes.td_from_profiler_op(operation)
403
- # Filter invalid epilogue schedules
404
- if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
405
- cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
406
- cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
407
- continue
408
- epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
409
-
410
- # Verify the maximum number of mainloop stages
411
- mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
412
- smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
413
- mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
414
- if mainloop_stages < 2:
415
- # Mainloop stages must >= 2
416
- continue
417
-
418
- new_possible_operations.add(operation)
419
- if len(new_possible_operations.all_operations) == 0:
420
- raise RuntimeError(
421
- "The epilogue consumes too much shared memory. "
422
- "No valid tile description is found in the generator.")
423
- self.possible_operations = new_possible_operations
424
-
425
-
426
- def run_setup(self):
427
- """
428
- Steps that must be taken before caling `plan.run()`
429
- """
430
- # Initialize the memory pool if, if not already done
431
- cutlass_cppgen.get_memory_pool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py DELETED
@@ -1,184 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utilities for expressing shapes
35
- """
36
-
37
- from cutlass_library import (
38
- ConvMode,
39
- ConvKind,
40
- LayoutType
41
- )
42
- from cutlass_cppgen.backend.c_types import (
43
- Conv2DProblemSize_,
44
- GemmCoord_,
45
- GemmCoordBatched_
46
- )
47
-
48
-
49
- class MatrixCoord:
50
- def __init__(self, row, col):
51
- self._row = row
52
- self._col = col
53
-
54
- @property
55
- def row(self):
56
- return self._row
57
-
58
- @property
59
- def column(self):
60
- return self._col
61
-
62
- def leading_dimension(self, layout: LayoutType) -> int:
63
- """
64
- Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.
65
-
66
- :param layout: layout of matrix
67
- :type layout: cutlass_library.LayoutType
68
-
69
- :returns: leading dimension
70
- :rtype: int
71
- """
72
- if layout == LayoutType.RowMajor:
73
- return self._col
74
- elif layout == LayoutType.ColumnMajor:
75
- return self._row
76
- else:
77
- raise Exception(f'Unsupported layout for leading dimension calculation: {layout}')
78
-
79
-
80
- class GemmCoord:
81
- def __init__(self, m: int, n: int, k: int):
82
- self._m = m
83
- self._n = n
84
- self._k = k
85
-
86
- @property
87
- def m(self) -> int:
88
- return self._m
89
-
90
- @property
91
- def n(self) -> int:
92
- return self._n
93
-
94
- @property
95
- def k(self) -> int:
96
- return self._k
97
-
98
- @property
99
- def mk(self) -> MatrixCoord:
100
- return MatrixCoord(self._m, self._k)
101
-
102
- @property
103
- def mn(self) -> MatrixCoord:
104
- return MatrixCoord(self._m, self._n)
105
-
106
- @property
107
- def kn(self) -> MatrixCoord:
108
- return MatrixCoord(self._k, self._n)
109
-
110
- @property
111
- def ctype(self) -> GemmCoord_:
112
- return GemmCoord_(self._m, self._n, self._k)
113
-
114
- def batched_ctype(self, batch_count: int) -> GemmCoordBatched_:
115
- return GemmCoordBatched_(self._m, self._n, self._k, batch_count)
116
-
117
-
118
- class Conv2DProblemSize:
119
- def __init__(
120
- self, n: int, h: int, w: int, c: int,
121
- k: int, r: int, s: int, c_: int,
122
- pad_h: int, pad_w: int, stride_h: int, stride_w: int,
123
- dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation,
124
- split_k_slices: int=1, groups: int=1):
125
-
126
- self.N = n
127
- self.H = h
128
- self.W = w
129
- self.C = c
130
- self.K = k
131
- self.R = r
132
- self.S = s
133
- self.pad_h = pad_h
134
- self.pad_w = pad_w
135
- self.stride_h = stride_h
136
- self.stride_w = stride_w
137
- self.dilation_h = dilation_h
138
- self.dilation_w = dilation_w
139
- self.mode = int(mode)
140
- self.split_k_slices = split_k_slices
141
- self.groups = groups
142
- self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1
143
- self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1
144
-
145
- @property
146
- def ctype(self) -> Conv2DProblemSize_:
147
- return Conv2DProblemSize_(self)
148
-
149
- def implicit_gemm_size(self, kind: ConvKind):
150
- if kind == ConvKind.Fprop:
151
- return GemmCoord(
152
- self.N * self.P * self.Q,
153
- self.K,
154
- self.R * self.S * self.C // self.groups
155
- )
156
- elif kind == ConvKind.Dgrad:
157
- return GemmCoord(
158
- self.N * self.H * self.W,
159
- self.C,
160
- self.R * self.S * self.K
161
- )
162
- elif kind == ConvKind.Wgrad:
163
- return GemmCoord(
164
- self.K,
165
- self.R * self.S * self.C,
166
- self.N * self.P * self.Q
167
- )
168
-
169
- @staticmethod
170
- def from_sizes(input_size, weight_size):
171
- K, R, S, _ = weight_size
172
- pad_h = R // 2
173
- pad_w = S // 2
174
- stride_h = 1
175
- stride_w = 1
176
- dilation_h = 1
177
- dilation_w = 1
178
- return Conv2DProblemSize(
179
- *input_size,
180
- *weight_size,
181
- pad_h, pad_w,
182
- stride_h, stride_w,
183
- dilation_h, dilation_w
184
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py DELETED
@@ -1,65 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Registry of swizzling functions
35
- """
36
-
37
- from cutlass_library import SwizzlingFunctor
38
-
39
-
40
- IdentitySwizzle1 = SwizzlingFunctor.Identity1
41
- IdentitySwizzle2 = SwizzlingFunctor.Identity2
42
- IdentitySwizzle4 = SwizzlingFunctor.Identity4
43
- IdentitySwizzle8 = SwizzlingFunctor.Identity8
44
- HorizontalSwizzle = SwizzlingFunctor.Horizontal
45
- ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK
46
- StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1
47
- StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4
48
- StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal
49
-
50
-
51
- _swizzling_functors = [
52
- IdentitySwizzle1,
53
- IdentitySwizzle2,
54
- IdentitySwizzle4,
55
- IdentitySwizzle8,
56
- HorizontalSwizzle,
57
- ThreadblockSwizzleStreamK,
58
- StridedDgradIdentitySwizzle1,
59
- StridedDgradIdentitySwizzle4,
60
- StridedDgradHorizontalSwizzle,
61
- ]
62
-
63
-
64
- def get_swizzling_functors():
65
- return _swizzling_functors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py DELETED
@@ -1,41 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- from cutlass_cppgen.utils.check import (
34
- alignment_or_default,
35
- calculate_smem_usage,
36
- calculate_smem_usage_per_stage,
37
- valid_cluster_shape,
38
- valid_schedule,
39
- valid_stage_count,
40
- update_alignment,
41
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py DELETED
@@ -1,262 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utility functions for checking constraints on kernels and calculating kernel attributes
35
- """
36
-
37
- import ctypes
38
-
39
- from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC
40
-
41
- import cutlass_cppgen
42
- from cutlass_cppgen.backend.library import TileDescription
43
-
44
-
45
- def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
46
- """
47
- Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
48
-
49
- :param td: tile description to compute shared memory of
50
- :type td: TileDescription
51
- :param operation_kind: identifier for the type of operation being performed
52
- :type operation_kind: cutlass_library.OperationKind
53
-
54
- :return: number of bytes of shared memory consumed by a single stage
55
- :rtype: int
56
- """
57
- m, n, k = td.blackwell_threadblock_shape
58
- if td.is_2sm:
59
- m //= 2
60
-
61
- if operation_kind == OperationKind.Gemm:
62
- stage_barrier_bytes = 32
63
- return (
64
- (DataTypeSize[td.math_instruction.element_a] * m * k // 8)
65
- + (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
66
- + stage_barrier_bytes
67
- )
68
- else:
69
- raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}")
70
-
71
-
72
- def calculate_smem_usage(operation) -> int:
73
- """
74
- Returns the amount of shared memory in bytes consumed by a kernel.
75
-
76
- :return: number of bytes of shared memory consumed by the operation
77
- :return: int
78
- """
79
- _per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind)
80
- return _per_stage * operation.tile_description.stages
81
-
82
-
83
- def valid_stage_count(
84
- cc: int,
85
- kernel_cc: int,
86
- td: TileDescription,
87
- element_C: cutlass_cppgen.DataType = None,
88
- element_D: cutlass_cppgen.DataType = None,
89
- verbose: bool = True) -> tuple:
90
- """
91
- Checks whether a device with `cc` supports the number of stages within `tile_description`, both
92
- based on raw limits on the number of stages and based on shared memory capacity
93
-
94
- :param cc: compute capability of device in question
95
- :type cc: int
96
- :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
97
- :type kernel_cc: int
98
- :param td: tile description to check
99
- :type td: TileDescription
100
- :param element_C: data type of operand C
101
- :type element_C: cutlass_cppgen.DataType
102
- :param element_D: data type of operand D
103
- :type element_D: cutlass_cppgen.DataType
104
- :param verbose: whether to log warnings
105
- :type verbose: bool
106
-
107
- :return: tuple with the first element indicating whether the provided tile description is
108
- valid for the provided device and the second element being an error message
109
- :rtype: tuple
110
- """
111
- if kernel_cc in [90, 100, 101, 103]:
112
- if (td.stages is None or td.stages == 0):
113
- # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
114
- # determines the stage count to use. Thus, all settings are valid in these scenarios.
115
- return (True, "")
116
- elif verbose:
117
- cutlass_cppgen.logger.warning(
118
- "Setting an explicit stage count for SM90 kernels currently may "
119
- "result in compilation errors if the combination of tile shape, "
120
- "stage count, and shared memory requirement of the epilogue exceeds "
121
- "the available shared memory per SM.")
122
-
123
- if td.stages <= 0:
124
- return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
125
-
126
- if cc < 80 and td.stages != 2:
127
- return (False, f"Tile description has stage count of {td.stages}, "
128
- f"but only 2 stages are supported on SM{cc}.")
129
-
130
- # The calculation below does not consider shared memory used by the epilogue and, thus,
131
- # only catches cases in which the mainloop exceeds the device's shared memory capacity.
132
- # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
133
- # mainloop and epilogue is shared.
134
- smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
135
- smem_usage_mainloop = (smem_per_stage * td.stages)
136
- smem_arch = SharedMemPerCC[cc] << 10
137
- if smem_usage_mainloop > smem_arch:
138
- return ( False,
139
- "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
140
- f"Details:\n"
141
- f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
142
- f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
143
- f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
144
-
145
- return (True, "")
146
-
147
-
148
- def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
149
- """
150
- Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.
151
-
152
- :param cc: compute capability of device in question
153
- :type cc: int
154
- :param cluster_shape: dimensions of thread block cluster shape to check
155
- :type cluster_shape: list
156
-
157
- :return: tuple with the first element indicating whether the provided cluster shape is
158
- valid for the provided device and the second element being an error message
159
- :rtype: tuple
160
- """
161
-
162
- if cc < 90 or cc in [120, 121]:
163
- if cluster_shape != [1, 1, 1]:
164
- return (False,
165
- f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of "
166
- f"{cluster_shape} for SM{cc}.")
167
- else:
168
- return (True, "")
169
-
170
- if len(cluster_shape) != 3:
171
- return (False,
172
- f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}")
173
-
174
- if cluster_shape[2] != 1:
175
- return (False,
176
- "CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
177
- f"Received cluster shape of {cluster_shape}.")
178
-
179
- return (True, "")
180
-
181
-
182
- def valid_schedule(
183
- cc: int,
184
- kernel_schedule: cutlass_cppgen.KernelScheduleType,
185
- epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
186
- tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
187
- """
188
- Checks that the kernel and epilogue schedules passed in are a valid combination for
189
- a device of compute capability ``cc``.
190
-
191
- :param cc: compute capability of device in question
192
- :type cc: int
193
- :param kernel_schedule: kernel schedule type
194
- :type kernel_schedule: cutlass_cppgen.KernelScheduleType
195
- :param epilogue_schedule: epilogue schedule type
196
- :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
197
- :param tile_scheduler: tile scheduler type
198
- :type tile_scheduler: cutlass_cppgen.TileSchedulerType
199
-
200
- :return: tuple with the first element indicating whether the provided schedules are
201
- valid for the provided device and the second element being an error message
202
- :rtype: tuple
203
- """
204
- kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
205
- epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
206
- tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
207
- if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default):
208
- return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)")
209
-
210
- if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)):
211
- return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
212
-
213
- if not tile_scheduler_default:
214
- cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
215
- cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
216
- if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
217
- return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
218
- return (True, "")
219
-
220
-
221
- def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
222
- """
223
- Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
224
- that `alignment_provided` does not exceed `default_alignment`.
225
-
226
- :param alignment_provided: alignment preference specified. Can be None.
227
- :type alignment_provided: int
228
- :param default_alignment: alignment to use if `alignment_provided` is None
229
- :type default_alignment: int
230
-
231
- :return: alignment to use
232
- :rtype: int
233
- """
234
- if alignment_provided is not None:
235
- if alignment_provided > default_alignment:
236
- raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
237
- return alignment_provided
238
-
239
- return default_alignment
240
-
241
-
242
- def update_alignment(alignment_provided:int, default_alignment: int) -> int:
243
- """
244
- Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
245
- that `alignment_provided` does not exceed `default_alignment`.
246
-
247
- :param alignment_provided: alignment preference specified. Can be None.
248
- :type alignment_provided: int
249
- :param default_alignment: alignment to use if `alignment_provided` is None
250
- :type default_alignment: int
251
-
252
- :return: alignment to use
253
- :rtype: int
254
- """
255
- if alignment_provided is not None:
256
- if alignment_provided > default_alignment:
257
- if alignment_provided % default_alignment == 0:
258
- return default_alignment
259
- raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
260
- return alignment_provided
261
-
262
- return default_alignment
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py DELETED
@@ -1,362 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utility functions for converting between frontend datatypes and CUTLASS datatypes
35
- """
36
-
37
- import cutlass_cppgen
38
- from cutlass_library import (
39
- DataTypeSize,
40
- MathOperation,
41
- MathInstruction
42
- )
43
- from cutlass_cppgen.backend.library import (
44
- TileDescription,
45
- )
46
-
47
- bfloat16_available = None
48
- cupy_available = None
49
- numpy_available = None
50
- torch_available = None
51
- _library_to_cupy_dict = None
52
- _library_to_numpy_dict = None
53
- _library_to_torch_dict = None
54
- _torch_to_library_dict = None
55
-
56
-
57
- def is_numpy_available():
58
- global numpy_available, _library_to_numpy_dict
59
- if numpy_available is None:
60
- try:
61
- import numpy as np
62
-
63
- numpy_available = True
64
- _library_to_numpy_dict = {
65
- cutlass_cppgen.DataType.f16: np.float16,
66
- cutlass_cppgen.DataType.f32: np.float32,
67
- cutlass_cppgen.DataType.f64: np.float64,
68
- cutlass_cppgen.DataType.s8: np.int8,
69
- cutlass_cppgen.DataType.s32: np.int32,
70
- }
71
- except ImportError:
72
- numpy_available = False
73
- _library_to_numpy_dict = {}
74
- return numpy_available
75
-
76
-
77
- def is_numpy_tensor(inp) -> bool:
78
- if is_numpy_available():
79
- import numpy as np
80
- return isinstance(inp, np.ndarray)
81
- return False
82
-
83
-
84
- def numpy_library_type(inp) -> cutlass_cppgen.DataType:
85
- if is_numpy_available():
86
- import numpy as np
87
- if inp == np.float16:
88
- return cutlass_cppgen.DataType.f16
89
- elif inp == np.float32:
90
- return cutlass_cppgen.DataType.f32
91
- elif inp == np.float64:
92
- return cutlass_cppgen.DataType.f64
93
- elif inp == np.int8:
94
- return cutlass_cppgen.DataType.s8
95
- elif inp == np.int32:
96
- return cutlass_cppgen.DataType.s32
97
- return None
98
-
99
-
100
- def numpy_type(inp):
101
- return _library_to_numpy_dict.get(inp, None)
102
-
103
-
104
- def is_cupy_available():
105
- global cupy_available
106
- if cupy_available is None:
107
- try:
108
- import cupy as cp
109
-
110
- cupy_available = True
111
- _library_to_cupy_dict = {
112
- cutlass_cppgen.DataType.f16: cp.float16,
113
- cutlass_cppgen.DataType.f32: cp.float32,
114
- cutlass_cppgen.DataType.f64: cp.float64,
115
- cutlass_cppgen.DataType.s8: cp.int8,
116
- cutlass_cppgen.DataType.s32: cp.int32,
117
- }
118
- except ImportError:
119
- cupy_available = False
120
- _library_to_cupy_dict = {}
121
- return cupy_available
122
-
123
-
124
- def is_cupy_tensor(inp) -> bool:
125
- if is_cupy_available():
126
- import cupy as cp
127
- return isinstance(inp, cp.ndarray)
128
- return False
129
-
130
-
131
- def cupy_library_type(inp) -> cutlass_cppgen.DataType:
132
- if is_cupy_available():
133
- import cupy as cp
134
- if inp == cp.float16:
135
- return cutlass_cppgen.DataType.f16
136
- elif inp == cp.float32:
137
- return cutlass_cppgen.DataType.f32
138
- elif inp == cp.float64:
139
- return cutlass_cppgen.DataType.f64
140
- return None
141
-
142
-
143
- def cupy_type(inp):
144
- return _library_to_cupy_dict.get(inp, None)
145
-
146
-
147
- def is_torch_available():
148
- global torch_available, _library_to_torch_dict, _torch_to_library_dict
149
- if torch_available is None:
150
- try:
151
- import torch
152
-
153
- torch_available = True
154
- _torch_to_library_dict = {
155
- torch.half: cutlass_cppgen.DataType.f16,
156
- torch.float16: cutlass_cppgen.DataType.f16,
157
- torch.bfloat16: cutlass_cppgen.DataType.bf16,
158
- torch.float: cutlass_cppgen.DataType.f32,
159
- torch.float32: cutlass_cppgen.DataType.f32,
160
- torch.double: cutlass_cppgen.DataType.f64,
161
- torch.float64: cutlass_cppgen.DataType.f64,
162
- torch.int8: cutlass_cppgen.DataType.s8,
163
- torch.int32: cutlass_cppgen.DataType.s32,
164
- torch.uint8: cutlass_cppgen.DataType.u8,
165
- }
166
-
167
- _library_to_torch_dict = {
168
- cutlass_cppgen.DataType.f16: torch.half,
169
- cutlass_cppgen.DataType.f16: torch.float16,
170
- cutlass_cppgen.DataType.bf16: torch.bfloat16,
171
- cutlass_cppgen.DataType.f32: torch.float,
172
- cutlass_cppgen.DataType.f32: torch.float32,
173
- cutlass_cppgen.DataType.f64: torch.double,
174
- cutlass_cppgen.DataType.f64: torch.float64,
175
- cutlass_cppgen.DataType.s8: torch.int8,
176
- cutlass_cppgen.DataType.s32: torch.int32,
177
- cutlass_cppgen.DataType.u8: torch.uint8,
178
- }
179
-
180
- def possibly_add_type(torch_type_name, cutlass_type):
181
- # Only try adding the type if the version of torch being used supports it
182
- if hasattr(torch, torch_type_name):
183
- torch_type = getattr(torch, torch_type_name)
184
- _torch_to_library_dict[torch_type] = cutlass_type
185
- _library_to_torch_dict[cutlass_type] = torch_type
186
-
187
- possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
188
- possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
189
-
190
- except ImportError:
191
- torch_available = False
192
- _torch_to_library_dict = {}
193
- _library_to_torch_dict = {}
194
- return torch_available
195
-
196
-
197
- def is_torch_tensor(inp) -> bool:
198
- if is_torch_available():
199
- import torch
200
- return isinstance(inp, torch.Tensor)
201
- return False
202
-
203
-
204
- def torch_library_type(inp) -> cutlass_cppgen.DataType:
205
- return _torch_to_library_dict.get(inp, None)
206
-
207
-
208
- def torch_type(inp):
209
- return _library_to_torch_dict.get(inp, None)
210
-
211
-
212
- def is_bfloat16_available():
213
- global bfloat16_available
214
-
215
- if bfloat16_available is None:
216
- try:
217
- import bfloat16
218
-
219
- bfloat16_available = True
220
- except ImportError:
221
- bfloat16_available = False
222
- return bfloat16_available
223
-
224
-
225
- def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
226
- if is_bfloat16_available():
227
- import bfloat16
228
- if inp == bfloat16.bfloat16:
229
- return cutlass_cppgen.DataType.bf16
230
-
231
-
232
- def bfloat16_type(inp):
233
- if is_bfloat16_available():
234
- import bfloat16
235
- if inp == cutlass_cppgen.DataType.bf16:
236
- return bfloat16.bfloat16
237
-
238
-
239
- def library_type(inp):
240
- if inp in DataTypeSize:
241
- return inp
242
-
243
- for cvt_fn in [
244
- bfloat16_library_type,
245
- cupy_library_type,
246
- numpy_library_type,
247
- torch_library_type,
248
- ]:
249
- out = cvt_fn(inp)
250
- if out is not None:
251
- return out
252
-
253
- raise Exception(f"No available conversion from type {inp} to a library type.")
254
-
255
-
256
- def _tensor_from_numpy(np_tensor):
257
- dtype = library_type(np_tensor.dtype)
258
- if np_tensor.flags.c_contiguous:
259
- layout = cutlass_cppgen.LayoutType.RowMajor
260
- elif np_tensor.flags.f_contiguous:
261
- layout = cutlass_cppgen.LayoutType.ColumnMajor
262
- return (dtype, layout)
263
-
264
-
265
- def _tensor_from_torch(pt_tensor):
266
- dtype = library_type(pt_tensor.dtype)
267
- return (dtype, cutlass_cppgen.LayoutType.RowMajor)
268
-
269
-
270
- def get_datatype_and_layout(tensor):
271
- if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
272
- return _tensor_from_numpy(tensor)
273
- elif is_torch_tensor(tensor):
274
- return _tensor_from_torch(tensor)
275
- elif isinstance(tensor, float) or isinstance(tensor, int):
276
- return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
277
- else:
278
- raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
279
-
280
-
281
- def get_tensor_shape(tensor, op="GEMM"):
282
- if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
283
- return tensor.shape
284
- elif is_torch_tensor(tensor):
285
- size = tensor.size()
286
- if op == "CONV":
287
- # PyTorch Tensors have shape NCHW
288
- return (size[0], size[2], size[3], size[1])
289
- else:
290
- return tuple(tensor.size())
291
- elif isinstance(tensor, float) or isinstance(tensor, int):
292
- return (1,)
293
- else:
294
- raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
295
-
296
-
297
- _math_operation_value_map = {x.value: x for x in MathOperation}
298
-
299
-
300
- def backend_math_operation(math_op: MathOperation):
301
- if math_op.value not in _math_operation_value_map.keys():
302
- raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
303
- return _math_operation_value_map[math_op.value]
304
-
305
-
306
- def construct_backend_td(td: cutlass_cppgen.TileDescription,
307
- kernel_schedule: cutlass_cppgen.KernelScheduleType,
308
- epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
309
- tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
310
- mi = td.math_instruction
311
- backend_mi = MathInstruction(
312
- mi.instruction_shape,
313
- mi.element_a,
314
- mi.element_b,
315
- mi.element_accumulator,
316
- mi.opcode_class,
317
- backend_math_operation(mi.math_operation)
318
- )
319
- cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
320
- return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
321
- backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
322
-
323
-
324
- def td_from_profiler_op(op) -> TileDescription:
325
- """
326
- Converts the profiler's TileDescription in ``op`` into the backend TileDescription
327
-
328
- :param op: profiler Operation
329
-
330
- :returns: backend TileDescription
331
- :rtype: cutlass_cppgen.backend.TileDescription
332
- """
333
- kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
334
- eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
335
- tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
336
- return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
337
-
338
-
339
- def td_from_profiler_td(td: TileDescription) -> TileDescription:
340
- """
341
- Converts the profiler's TileDescription into the backend TileDescription
342
-
343
- :param td: profiler TileDescription
344
- :type td: cutlass_cppgen.TileDescription
345
-
346
- :returns: backend TileDescription
347
- :rtype: cutlass_cppgen.backend.TileDescription
348
- """
349
- return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
350
-
351
-
352
- def to_camel_case(snake_str):
353
- return "".join(x.capitalize() for x in snake_str.lower().split("_"))
354
-
355
-
356
- def getattr_enum(obj, attr_name):
357
- # The attr_name is under the snake_case
358
- camel_attr = to_camel_case(attr_name)
359
- if hasattr(obj, camel_attr):
360
- return getattr(obj, camel_attr)
361
- else:
362
- raise Exception(f"Invalid option: {attr_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py DELETED
@@ -1,41 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
- import importlib
33
- from typing import Any
34
-
35
- def lazy_import(mod_name: str) -> Any:
36
- class Lazy:
37
- def __getattr__(self, name:str) -> Any:
38
- module = importlib.import_module(mod_name)
39
- return getattr(module, name)
40
-
41
- return Lazy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py DELETED
@@ -1,196 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Profiler based on the cuda events
35
- """
36
-
37
- import re
38
- import subprocess
39
-
40
- from cutlass_cppgen.utils.lazy_import import lazy_import
41
- cuda = lazy_import("cuda.cuda")
42
- cudart = lazy_import("cuda.cudart")
43
- import numpy as np
44
-
45
- from cutlass_cppgen import CUTLASS_PATH
46
- from cutlass_cppgen.backend.library import DataTypeSize
47
- from cutlass_cppgen.op.op import OperationBase
48
- from cutlass_cppgen.shape import GemmCoord
49
- from cutlass_cppgen.utils.datatypes import is_numpy_tensor
50
-
51
-
52
- class GpuTimer:
53
- def __init__(self) -> None:
54
- self.events = [
55
- cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
56
- cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
57
- ]
58
-
59
- def start(self, stream=None):
60
- if not stream:
61
- stream = cuda.CUstream(0)
62
-
63
- (err,) = cuda.cuEventRecord(self.events[0], stream)
64
- if err != cuda.CUresult.CUDA_SUCCESS:
65
- raise RuntimeError(f"CUDA Error {str(err)}")
66
-
67
- def stop(self, stream=None):
68
- if not stream:
69
- stream = cuda.CUstream(0)
70
-
71
- (err,) = cuda.cuEventRecord(self.events[1], stream)
72
- if err != cuda.CUresult.CUDA_SUCCESS:
73
- raise RuntimeError(f"CUDA Error {str(err)}")
74
- pass
75
-
76
- def stop_and_wait(self, stream=None):
77
- if not stream:
78
- stream = cuda.CUstream(0)
79
-
80
- self.stop(stream)
81
- if stream:
82
- (err,) = cuda.cuStreamSynchronize(stream)
83
- if err != cuda.CUresult.CUDA_SUCCESS:
84
- raise RuntimeError(f"CUDA Error {str(err)}")
85
- else:
86
- (err,) = cudart.cudaDeviceSynchronize()
87
- if err != cuda.CUresult.CUDA_SUCCESS:
88
- raise RuntimeError(f"CUDA Error {str(err)}")
89
-
90
- def duration(self, iterations=1):
91
- err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
92
- if err != cuda.CUresult.CUDA_SUCCESS:
93
- raise RuntimeError(f"CUDA Error {str(err)}")
94
- return duration / float(iterations)
95
-
96
-
97
- class CUDAEventProfiler:
98
- def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None:
99
- self.arguments = op.run(*args, **kwargs)
100
- self.operation = op.operation
101
- self.warmup_iterations = warmup_iterations
102
- self.iterations = iterations
103
- self.timer = GpuTimer()
104
-
105
- #
106
- # Cutlass Python Interface Profiler
107
- #
108
-
109
- def __call__(self):
110
- for _ in range(self.warmup_iterations):
111
- self.operation.run(self.arguments)
112
-
113
- self.timer.start()
114
- for _ in range(self.iterations):
115
- self.operation.run(self.arguments)
116
-
117
- self.timer.stop_and_wait()
118
- runtime = self.timer.duration(self.iterations)
119
- return runtime
120
-
121
- #
122
- # CUTLASS Profiler
123
- #
124
-
125
- def run_cutlass_profiler(self):
126
- alpha = 1.0
127
- beta = 1.0
128
-
129
- profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
130
- kernel_name = self.operation.procedural_name()
131
- verification_providers = "device"
132
- provider = "cutlass"
133
- problem_size = self.arguments.problem_size
134
-
135
- if "cutlass3x" in kernel_name:
136
- # cutlass3x generator only have column-major output
137
- layout_name = self.operation.layout_name_3x()
138
- if layout_name[-1] == "t":
139
- new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"])
140
- problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
141
- kernel_name = kernel_name.replace(layout_name, new_layout_name)
142
-
143
- batch_count = self.arguments.batch_count
144
-
145
- cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
146
- f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
147
- f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\
148
- f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"
149
-
150
- result = subprocess.getoutput(cmd)
151
-
152
- m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
153
- runtime = float(m.group("runtime"))
154
-
155
- m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
156
- bytes = int(m.group("bytes"))
157
-
158
- m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
159
- flops = int(m.group("flops"))
160
-
161
- # check if the problem size matches
162
- assert bytes == self.bytes(problem_size, batch_count, beta)
163
- assert flops == self.flops(problem_size, batch_count, beta)
164
-
165
- return runtime
166
-
167
- def bytes(self, problem_size, batch_count=1, beta=0.0):
168
- m = problem_size.m()
169
- n = problem_size.n()
170
- k = problem_size.k()
171
-
172
- bytes = (
173
- (DataTypeSize[self.operation.A.element] * m // 8) * k
174
- + (DataTypeSize[self.operation.B.element] * n // 8) * k
175
- + (DataTypeSize[self.operation.C.element] * m // 8) * n
176
- )
177
-
178
- if beta != 0:
179
- bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
180
-
181
- bytes *= batch_count
182
-
183
- return bytes
184
-
185
- def flops(self, problem_size, batch_count=1, beta=0.0):
186
- m = problem_size.m()
187
- n = problem_size.n()
188
- k = problem_size.k()
189
-
190
- flops_ = (m * n * k) * 2 * batch_count
191
-
192
- if beta != 0:
193
- flops_ += m * n * batch_count * 2
194
-
195
- return flops_
196
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py DELETED
@@ -1,63 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- import os
34
- import sys
35
-
36
- from . import conv2d_operation
37
- from . import conv3d_operation
38
- from . import emit_kernel_listing
39
- from . import gemm_operation
40
-
41
- if '-m' not in sys.argv:
42
- # Do not import generator when running python -m cutlass_library.generator to
43
- # avoid double-import warnings
44
- from . import generator
45
-
46
- from . import library
47
- from . import manifest
48
- from . import rank_2k_operation
49
- from . import rank_k_operation
50
- from . import symm_operation
51
- from . import trmm_operation
52
- # Make enum types from library.py accessible via cutlass_library.*
53
- from .library import *
54
-
55
- # Set up `source` to point to the path containing the CUTLASS source.
56
- # Check first if the path contains a `source` subdirectory -- this will
57
- # be the case when the package has been installed via pip. Otherwise,
58
- # default to the root of CUTLASS.
59
- install_source_path = os.path.join(__path__[0], 'source')
60
- if os.path.isdir(install_source_path):
61
- source_path = install_source_path
62
- else:
63
- source_path = os.path.join(__path__[0], '../..')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py DELETED
@@ -1,621 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utilities for emitting Conv2d kernels
35
- """
36
-
37
- import enum
38
- import logging
39
- import os.path
40
- import shutil
41
- from string import Template
42
-
43
- try:
44
- import builtins
45
- if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
- raise ImportError("Disabling attempt to import cutlass_library")
47
- from cutlass_library.library import *
48
- from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
49
- except ImportError:
50
- from library import *
51
- from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
52
-
53
- _LOGGER = logging.getLogger(__name__)
54
-
55
- ###################################################################################################
56
-
57
- #
58
- class Conv2dOperation:
59
- #
60
- def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
61
- stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
62
- group_mode = GroupMode.NoneGroup):
63
-
64
- self.operation_kind = OperationKind.Conv2d
65
- self.arch = arch
66
- self.tile_description = tile_description
67
- self.conv_kind = conv_kind
68
- self.A = A
69
- self.B = B
70
- self.C = C
71
- self.element_epilogue = element_epilogue
72
- self.epilogue_functor = epilogue_functor
73
- self.iterator_algorithm = iterator_algorithm
74
- self.stride_support = stride_support
75
- self.swizzling_functor = swizzling_functor
76
- self.group_mode = group_mode
77
- #
78
- def is_complex(self):
79
- complex_operators = [
80
- MathOperation.multiply_add_complex,
81
- MathOperation.multiply_add_complex_gaussian
82
- ]
83
- return self.tile_description.math_instruction.math_operation in complex_operators
84
-
85
- #
86
- def is_mixed_input(self):
87
- return self.A.element != self.B.element
88
-
89
- #
90
- def accumulator_type(self):
91
- accum = self.tile_description.math_instruction.element_accumulator
92
-
93
- if self.is_complex():
94
- return get_complex_from_real(accum)
95
-
96
- return accum
97
-
98
- #
99
- def core_name(self):
100
- ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
101
-
102
- intermediate_type = ''
103
-
104
- if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
105
- inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
106
- if self.tile_description.math_instruction.element_a != self.A.element and \
107
- self.tile_description.math_instruction.element_a != self.accumulator_type():
108
- intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
109
- else:
110
- inst_shape = ''
111
-
112
- return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
113
- inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
114
-
115
- #
116
- def extended_name(self):
117
- ''' Append data types if they differ from compute type. '''
118
- if self.C.element != self.tile_description.math_instruction.element_accumulator and \
119
- self.A.element != self.tile_description.math_instruction.element_accumulator:
120
- extended_name = "${element_c}_${core_name}_${element_a}"
121
- elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
122
- self.A.element != self.tile_description.math_instruction.element_accumulator:
123
- extended_name = "${core_name}_${element_a}"
124
- else:
125
- extended_name = "${core_name}"
126
-
127
- extended_name = SubstituteTemplate(extended_name, {
128
- 'element_a': DataTypeNames[self.A.element],
129
- 'element_c': DataTypeNames[self.C.element],
130
- 'core_name': self.core_name()
131
- })
132
-
133
- return extended_name
134
-
135
- #
136
- def layout_name(self):
137
- return "%s" % (ShortLayoutTypeNames[self.A.layout])
138
-
139
- #
140
- def configuration_name(self):
141
- ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
142
-
143
- opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
144
-
145
- threadblock = self.tile_description.procedural_name()
146
-
147
- # grouped conv
148
- if self.group_mode != GroupMode.NoneGroup:
149
- group_conv_name = f"{GroupModeNames[self.group_mode]}_"
150
- else:
151
- group_conv_name = ""
152
-
153
- if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad:
154
- configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
155
- else:
156
- configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
157
-
158
- return SubstituteTemplate(
159
- configuration_name,
160
- {
161
- 'opcode_class': opcode_class_name,
162
- 'extended_name': self.extended_name(),
163
- 'threadblock': threadblock,
164
- 'layout': self.layout_name(),
165
- 'alignment': "%d" % self.A.alignment,
166
- 'group_conv_name': group_conv_name
167
- }
168
- )
169
-
170
- #
171
- def procedural_name(self):
172
- ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
173
- return self.configuration_name()
174
-
175
- ###################################################################################################
176
- #
177
- # Emits single instances of a CUTLASS device-wide operator
178
- #
179
- ###################################################################################################
180
-
181
- class EmitConv2dInstance:
182
- def __init__(self):
183
- # Emitter for CUTLASS 3 convolution operations
184
- self.conv3x_emitter = EmitConv3xInstance()
185
- self.template = """
186
- // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
187
- using ${operation_name}_base =
188
- typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
189
- ${element_a},
190
- ${layout_a},
191
- ${element_b},
192
- ${layout_b},
193
- ${element_c},
194
- ${layout_c},
195
- ${element_accumulator},
196
- ${opcode_class},
197
- ${arch},
198
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
199
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
200
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
201
- ${epilogue_functor}<
202
- ${element_c},
203
- ${epilogue_vector_length},
204
- ${element_accumulator},
205
- ${element_epilogue}
206
- >,
207
- ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
208
- ${stages},
209
- ${math_operator},
210
- ${iterator_algorithm},
211
- ${stride_support},
212
- ${align_a},
213
- ${align_b}
214
- >::Kernel;
215
- """
216
- self.template_group_conv = """
217
- // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
218
- using ${operation_name}_base =
219
- typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
220
- ${element_a},
221
- ${layout_a},
222
- ${element_b},
223
- ${layout_b},
224
- ${element_c},
225
- ${layout_c},
226
- ${element_accumulator},
227
- ${opcode_class},
228
- ${arch},
229
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
230
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
231
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
232
- ${epilogue_functor}<
233
- ${element_c},
234
- ${epilogue_vector_length},
235
- ${element_accumulator},
236
- ${element_epilogue}
237
- >,
238
- ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
239
- ${stages},
240
- ${math_operator},
241
- ${group_mode},
242
- ${iterator_algorithm},
243
- ${stride_support},
244
- ${align_a},
245
- ${align_b}
246
- >::Kernel;
247
- """
248
- self.template_depthwise_direct_conv = """
249
- // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
250
- using ${operation_name}_base =
251
- typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
252
- ${element_a},
253
- ${layout_a},
254
- ${element_b},
255
- ${layout_b},
256
- ${element_c},
257
- ${layout_c},
258
- ${element_accumulator},
259
- ${opcode_class},
260
- ${arch},
261
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
262
- cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
263
- cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
264
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
265
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
266
- ${epilogue_functor}<
267
- ${element_c},
268
- ${epilogue_vector_length},
269
- ${element_accumulator},
270
- ${element_epilogue},
271
- cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
272
- >,
273
-
274
- cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
275
- 1,
276
- ${threadblock_output_shape_n},
277
- ${threadblock_output_shape_p},
278
- ${threadblock_output_shape_q}>,
279
- ${stages},
280
- ${math_operator},
281
- ${iterator_algorithm},
282
- ${stride_support},
283
- cutlass::MatrixShape<${stride_r}, ${stride_s}>,
284
- cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
285
- >::Kernel;
286
- """
287
-
288
- def arch_number_to_type(self, arch: int):
289
- return f"cutlass::arch::Sm{arch}"
290
-
291
- def emit(self, operation):
292
- _LOGGER.debug("*** EmitConv2dInstance::emit")
293
- _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
294
-
295
- if hasattr(operation, 'is_3x') and operation.is_3x:
296
- _LOGGER.debug("*** CUTLASS 3 operation")
297
- return self.conv3x_emitter.emit(operation)
298
-
299
- _LOGGER.debug("*** CUTLASS 2 operation")
300
-
301
- warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
302
-
303
- epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
304
-
305
- values = {
306
- 'operation_name': operation.procedural_name(),
307
- 'conv_kind': ConvKindTag[operation.conv_kind],
308
- 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
309
- 'element_a': DataTypeTag[operation.A.element],
310
- 'layout_a': LayoutTag[operation.A.layout],
311
- 'element_b': DataTypeTag[operation.B.element],
312
- 'layout_b': LayoutTag[operation.B.layout],
313
- 'element_c': DataTypeTag[operation.C.element],
314
- 'layout_c': LayoutTag[operation.C.layout],
315
- 'element_accumulator': DataTypeTag[operation.accumulator_type()],
316
- 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
317
- 'arch': "cutlass::arch::Sm%d" % operation.arch,
318
- 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
319
- 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
320
- 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
321
- 'warp_shape_m': str(warp_shape[0]),
322
- 'warp_shape_n': str(warp_shape[1]),
323
- 'warp_shape_k': str(warp_shape[2]),
324
- 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
325
- 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
326
- 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
327
- 'epilogue_vector_length': str(epilogue_vector_length),
328
- 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
329
- 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
330
- 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
331
- 'stages': str(operation.tile_description.stages),
332
- 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
333
- 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
334
- 'stride_support': StrideSupportTag[operation.stride_support],
335
- 'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
336
- MathOperationTag[operation.tile_description.math_instruction.math_operation],
337
- 'align_a': str(operation.A.alignment),
338
- 'align_b': str(operation.B.alignment),
339
- }
340
-
341
- if operation.group_mode == GroupMode.NoneGroup:
342
- _LOGGER.debug("*** group_mode=NoneGroup")
343
- return SubstituteTemplate(self.template, values)
344
-
345
- elif operation.group_mode == GroupMode.Depthwise:
346
- _LOGGER.debug("*** group_mode=Depthwise")
347
- values['group_mode'] = GroupModeTag[operation.group_mode]
348
- # Setup other template params
349
- values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
350
- values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
351
- values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
352
-
353
- values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
354
-
355
- values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
356
- values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
357
-
358
- values['stride_r'] = str(operation.tile_description.stride[0])
359
- values['stride_s'] = str(operation.tile_description.stride[1])
360
-
361
- values['dilation_r'] = str(operation.tile_description.dilation[0])
362
- values['dilation_s'] = str(operation.tile_description.dilation[1])
363
-
364
- return SubstituteTemplate(self.template_depthwise_direct_conv, values)
365
-
366
- else:
367
- _LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode])
368
- values['group_mode'] = GroupModeTag[operation.group_mode]
369
- return SubstituteTemplate(self.template_group_conv, values)
370
-
371
- ###################################################################################################
372
- #
373
- # Generator functions for all layouts
374
- #
375
- ###################################################################################################
376
-
377
- #
378
- def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
379
- _LOGGER.debug("*** GenerateConv2dTensorOp")
380
-
381
- for tile in tile_descriptions:
382
- for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
383
-
384
- if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
385
-
386
- #
387
- output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
388
- if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
389
- else [tile.math_instruction.element_accumulator,]
390
-
391
- for output_type in output_types:
392
- A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
393
- B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
394
- C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type])))
395
-
396
- manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
397
-
398
- class EmitConv2dIncludes:
399
- '''Emit includes that are specific to the operation.'''
400
-
401
- def __init__(self):
402
- self.includes = ['conv2d_operation.h']
403
- self.emitter_3x = EmitConv3xIncludes()
404
-
405
- def operation_is_3x(self, operation) -> bool:
406
- """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
407
- return hasattr(operation, 'is_3x') and operation.is_3x
408
-
409
- def emit(self, operation) -> str:
410
- if self.operation_is_3x(operation):
411
- return self.emitter_3x.emit(operation)
412
-
413
- return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
414
- "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
415
-
416
- ###################################################################################################
417
- #
418
- # Emitters functions for all targets
419
- #
420
- ###################################################################################################
421
-
422
- class EmitConv2dConfigurationLibrary:
423
- def __init__(self, operation_path, configuration_name):
424
- self.configuration_name = configuration_name
425
- self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
426
-
427
- self.instance_emitter = EmitConv2dInstance()
428
- self.includes_emitter = EmitConv2dIncludes()
429
-
430
- self.header_template = """
431
- /*
432
- Generated by conv2d_operation.py - Do not edit.
433
- */
434
-
435
- ///////////////////////////////////////////////////////////////////////////////////////////////////
436
-
437
- #include "cutlass/cutlass.h"
438
- #include "cutlass/library/library.h"
439
- #include "cutlass/library/manifest.h"
440
-
441
- #include "library_internal.h"
442
- """
443
-
444
- self.instance_template = """
445
- ${stub_begin}
446
- ${operation_instance}
447
- // Derived class
448
- struct ${operation_name} :
449
- public ${operation_name}_base { };
450
- ${stub_end}
451
- ///////////////////////////////////////////////////////////////////////////////////////////////////
452
-
453
- """
454
-
455
- self.configuration_header = """
456
-
457
- namespace cutlass {
458
- namespace library {
459
-
460
- // Initialize all instances
461
- void initialize_${configuration_name}(Manifest &manifest) {
462
- """
463
-
464
- self.configuration_instance = """${stub_begin}
465
- using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
466
- ${operation_name}>;
467
-
468
- manifest.append(new cutlass::library::${operation_wrapper}<
469
- Operation_${operation_name}
470
- >(
471
- "${operation_name}"
472
- ));
473
- ${stub_end}
474
- """
475
-
476
- self.configuration_epilogue = "}\n"
477
-
478
- self.epilogue_template = """
479
-
480
- ///////////////////////////////////////////////////////////////////////////////////////////////////
481
-
482
- } // namespace library
483
- } // namespace cutlass
484
-
485
- ///////////////////////////////////////////////////////////////////////////////////////////////////
486
-
487
- """
488
-
489
- def operation_is_3x(self, operation):
490
- """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
491
- return hasattr(operation, 'is_3x') and operation.is_3x
492
-
493
- def __enter__(self):
494
- """
495
- Open the configuration_file, and write the "header" C++ code to it.
496
-
497
- The "header" consists of a comment (that this is generated code,
498
- so it should not be edited), and includes that are common
499
- to all kinds of kernels.
500
- """
501
- _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__')
502
- _LOGGER.debug('*** configuration_path (file to write): ' +
503
- str(self.configuration_path))
504
- _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
505
- self.configuration_file = open(self.configuration_path, "w")
506
-
507
- self.configuration_file.write(SubstituteTemplate(self.header_template, {
508
- 'configuration_name': self.configuration_name
509
- }))
510
- self.operations = []
511
- return self
512
-
513
- def emit(self, operation):
514
- """
515
- Write three pieces of C++ code to the configuration_file
516
- (that was opened by the __enter__ method above):
517
-
518
- 1. the header includes that are specific to the operation
519
- (CUTLASS 2 vs. CUTLASS 3);
520
-
521
- 2. the "operation instance" (a "using" declaration ending in "_base"); and
522
-
523
- 3. the "operation name" (declaration and definition of a derived class
524
- of the above operation instance).
525
-
526
- The "using" declaration turns a C++ class name, possibly namespace-qualified,
527
- possibly also with angle brackets, into a C-style, easily demangled identifier.
528
- """
529
- _LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit')
530
- _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
531
- self.operations.append(operation)
532
-
533
- self.configuration_file.write(self.includes_emitter.emit(operation))
534
-
535
- stub_begin = ''
536
- stub_end = ''
537
- # It can be useful to stub (comment) out instantiations for testing.
538
- # In this case, one need only set is_stub to True.
539
- is_stub = False
540
- if is_stub:
541
- stub_begin = "// STUB for now\n#if 0"
542
- stub_end = '#endif // 0'
543
-
544
- self.configuration_file.write(Template(self.instance_template).substitute({
545
- 'configuration_name': self.configuration_name,
546
- 'operation_name': operation.procedural_name(),
547
- 'operation_instance': self.instance_emitter.emit(operation),
548
- 'stub_begin': stub_begin,
549
- 'stub_end': stub_end
550
- }))
551
-
552
- def __exit__(self, exception_type, exception_value, traceback):
553
- """
554
- Write the rest of the C++ code to the configuration_file, and close the file.
555
-
556
- The "rest of the C++ code" has the following components.
557
-
558
- 1. Configuration header: Open the namespace(s), and open the definition
559
- of the "initialize_${configuration_name}" registration function
560
- that registers the operation with the Manifest.
561
- ("Registration" helps turn C++ compile-time polymorphism
562
- (via template parameters) into a run-time choice of parameters.)
563
-
564
- 2. Configuration instance: In the body of the registration function,
565
- make a "using" declaration Operation_${operation_name} for the
566
- operation type (which uses operation_name as its template argument).
567
- Then, tell the manifest about the operation via a "manifest.append" call.
568
- The argument of the call is a new instance of
569
- "SomethingOperation<Operation_${operation_name}>"
570
- (replace Something with a specific name).
571
-
572
- 3. Configuration epilogue: Close the definition of the registration function.
573
-
574
- 4. Epilogue template: Close the namespace(s).
575
- """
576
-
577
- _LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__')
578
- _LOGGER.debug('*** configuration_path (file to write): ' +
579
- str(self.configuration_path))
580
- _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
581
-
582
- self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
583
- 'configuration_name': self.configuration_name
584
- }))
585
-
586
- for operation in self.operations:
587
- stub_begin = ''
588
- stub_end = ''
589
- # It can be useful to stub (comment) out instantiations for testing.
590
- # In this case, one need only set is_stub to True.
591
- is_stub = False
592
- if is_stub:
593
- stub_begin = "// STUB for now\n#if 0"
594
- stub_end = "#endif // 0"
595
-
596
- if operation.group_mode == GroupMode.Depthwise:
597
- kernel_name = 'DirectConvolution'
598
- operation_wrapper = 'DirectConv2dOperation'
599
- else:
600
- kernel_name = 'ImplicitGemmConvolution'
601
- operation_wrapper = 'Conv2dOperation'
602
- if self.operation_is_3x(operation):
603
- kernel_name = 'ConvUniversalAdapter'
604
- operation_wrapper = 'ConvOperation3x'
605
-
606
- self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
607
- 'configuration_name': self.configuration_name,
608
- 'operation_name': operation.procedural_name(),
609
- 'kernel_name': kernel_name,
610
- 'operation_wrapper': operation_wrapper,
611
- 'stub_begin': stub_begin,
612
- 'stub_end': stub_end
613
- }))
614
-
615
- self.configuration_file.write(self.configuration_epilogue)
616
- self.configuration_file.write(self.epilogue_template)
617
- self.configuration_file.close()
618
-
619
-
620
- ###################################################################################################
621
- ###################################################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py DELETED
@@ -1,482 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utilities for emitting Conv3d kernels
35
- """
36
-
37
- import enum
38
- import logging
39
- import os.path
40
- import shutil
41
- from string import Template
42
-
43
- try:
44
- import builtins
45
- if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
- raise ImportError("Disabling attempt to import cutlass_library")
47
- from cutlass_library.library import *
48
- from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
49
- except ImportError:
50
- from library import *
51
- from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
52
-
53
- _LOGGER = logging.getLogger(__name__)
54
-
55
- ###################################################################################################
56
-
57
- #
58
- class Conv3dOperation:
59
- #
60
- def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
61
- stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
62
-
63
- self.operation_kind = OperationKind.Conv3d
64
- self.arch = arch
65
- self.tile_description = tile_description
66
- self.conv_kind = conv_kind
67
- self.A = A
68
- self.B = B
69
- self.C = C
70
- self.element_epilogue = element_epilogue
71
- self.epilogue_functor = epilogue_functor
72
- self.iterator_algorithm = iterator_algorithm
73
- self.stride_support = stride_support
74
- self.swizzling_functor = swizzling_functor
75
-
76
- #
77
- def is_mixed_input(self):
78
- return self.A.element != self.B.element
79
-
80
- #
81
- def core_name(self):
82
- ''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
83
-
84
- intermediate_type = ''
85
-
86
- if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
87
- inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
88
- if self.tile_description.math_instruction.element_a != self.A.element and \
89
- self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
90
- intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
91
- else:
92
- inst_shape = ''
93
-
94
- return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \
95
- inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
96
-
97
- #
98
- def extended_name(self):
99
- ''' Append data types if they differ from compute type. '''
100
- if self.C.element != self.tile_description.math_instruction.element_accumulator and \
101
- self.A.element != self.tile_description.math_instruction.element_accumulator:
102
- extended_name = "${element_c}_${core_name}_${element_a}"
103
- elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
104
- self.A.element != self.tile_description.math_instruction.element_accumulator:
105
- extended_name = "${core_name}_${element_a}"
106
- else:
107
- extended_name = "${core_name}"
108
-
109
- extended_name = SubstituteTemplate(extended_name, {
110
- 'element_a': DataTypeNames[self.A.element],
111
- 'element_c': DataTypeNames[self.C.element],
112
- 'core_name': self.core_name()
113
- })
114
-
115
- return extended_name
116
-
117
- #
118
- def configuration_name(self):
119
- ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
120
-
121
- opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
122
-
123
- threadblock = "%dx%d_%dx%d" % (
124
- self.tile_description.threadblock_shape[0],
125
- self.tile_description.threadblock_shape[1],
126
- self.tile_description.threadblock_shape[2],
127
- self.tile_description.stages
128
- )
129
-
130
- if self.stride_support == StrideSupport.Unity:
131
- configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride"
132
- else:
133
- configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}"
134
-
135
- return SubstituteTemplate(
136
- configuration_name,
137
- {
138
- 'opcode_class': opcode_class_name,
139
- 'extended_name': self.extended_name(),
140
- 'threadblock': threadblock,
141
- }
142
- )
143
-
144
- #
145
- def procedural_name(self):
146
- ''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
147
- return self.configuration_name()
148
-
149
- ###################################################################################################
150
- #
151
- # Emits single instances of a CUTLASS device-wide operator
152
- #
153
- ###################################################################################################
154
-
155
- class EmitConv3dInstance:
156
- def __init__(self):
157
- # Emitter for CUTLASS 3 convolution operations
158
- self.conv3x_emitter = EmitConv3xInstance()
159
- self.template = """
160
- // Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
161
- using ${operation_name}_base =
162
- typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}<
163
- ${element_a},
164
- cutlass::layout::TensorNDHWC,
165
- ${element_b},
166
- cutlass::layout::TensorNDHWC,
167
- ${element_c},
168
- cutlass::layout::TensorNDHWC,
169
- ${element_accumulator},
170
- ${opcode_class},
171
- ${arch},
172
- cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
173
- cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
174
- cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
175
- ${epilogue_functor}<
176
- ${element_c},
177
- ${epilogue_vector_length},
178
- ${element_accumulator},
179
- ${element_epilogue}
180
- >,
181
- ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
182
- ${stages},
183
- cutlass::arch::OpMultiplyAdd,
184
- ${iterator_algorithm},
185
- ${stride_support}
186
- >::Kernel;
187
- """
188
-
189
- def emit(self, operation):
190
- _LOGGER.debug("*** EmitConv3dInstance::emit")
191
- _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
192
-
193
- if hasattr(operation, 'is_3x') and operation.is_3x:
194
- _LOGGER.debug("*** CUTLASS 3 operation")
195
- return self.conv3x_emitter.emit(operation)
196
-
197
- _LOGGER.debug("*** CUTLASS 2 operation")
198
-
199
- warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
200
-
201
- epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
202
-
203
- values = {
204
- 'operation_name': operation.procedural_name(),
205
- 'conv_kind': ConvKindTag[operation.conv_kind],
206
- 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
207
- 'element_a': DataTypeTag[operation.A.element],
208
- 'layout_a': LayoutTag[operation.A.layout],
209
- 'element_b': DataTypeTag[operation.B.element],
210
- 'layout_b': LayoutTag[operation.B.layout],
211
- 'element_c': DataTypeTag[operation.C.element],
212
- 'layout_c': LayoutTag[operation.C.layout],
213
- 'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
214
- 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
215
- 'arch': "cutlass::arch::Sm%d" % operation.arch,
216
- 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
217
- 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
218
- 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
219
- 'warp_shape_m': str(warp_shape[0]),
220
- 'warp_shape_n': str(warp_shape[1]),
221
- 'warp_shape_k': str(warp_shape[2]),
222
- 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
223
- 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
224
- 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
225
- 'epilogue_vector_length': str(epilogue_vector_length),
226
- 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
227
- 'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
228
- 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
229
- 'stages': str(operation.tile_description.stages),
230
- 'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
231
- 'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
232
- 'stride_support': StrideSupportTag[operation.stride_support]
233
- }
234
-
235
- return SubstituteTemplate(self.template, values)
236
-
237
- ###################################################################################################
238
- #
239
- # Generator functions for all layouts
240
- #
241
- ###################################################################################################
242
-
243
- #
244
- def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
245
-
246
- for tile in tile_descriptions:
247
- for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
248
-
249
- if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
250
-
251
- #
252
- output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
253
- if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
254
- else [tile.math_instruction.element_accumulator,]
255
-
256
- for output_type in output_types:
257
- A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
258
- B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
259
- C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type])))
260
-
261
- manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
262
-
263
- class EmitConv3dIncludes:
264
- '''Emit includes that are specific to the operation.'''
265
-
266
- def __init__(self):
267
- self.includes = ['conv3d_operation.h']
268
- self.emitter_3x = EmitConv3xIncludes()
269
-
270
- def operation_is_3x(self, operation) -> bool:
271
- """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
272
- return hasattr(operation, 'is_3x') and operation.is_3x
273
-
274
- def emit(self, operation) -> str:
275
- if self.operation_is_3x(operation):
276
- return self.emitter_3x.emit(operation)
277
-
278
- return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
279
- "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
280
-
281
- ###################################################################################################
282
- #
283
- # Emitters functions for all targets
284
- #
285
- ###################################################################################################
286
-
287
- class EmitConv3dConfigurationLibrary:
288
- def __init__(self, operation_path, configuration_name):
289
- self.configuration_name = configuration_name
290
- self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
291
-
292
- self.instance_emitter = EmitConv3dInstance()
293
- self.includes_emitter = EmitConv3dIncludes()
294
-
295
- self.header_template = """
296
- /*
297
- Generated by conv3d_operation.py - Do not edit.
298
- */
299
-
300
- ///////////////////////////////////////////////////////////////////////////////////////////////////
301
-
302
- #include "cutlass/cutlass.h"
303
- #include "cutlass/library/library.h"
304
- #include "cutlass/library/manifest.h"
305
-
306
- #include "library_internal.h"
307
- """
308
-
309
- self.instance_template = """
310
- ${stub_begin}
311
- ${operation_instance}
312
- // Derived class
313
- struct ${operation_name} :
314
- public ${operation_name}_base { };
315
- ${stub_end}
316
- ///////////////////////////////////////////////////////////////////////////////////////////////////
317
-
318
- """
319
-
320
- self.configuration_header = """
321
-
322
- namespace cutlass {
323
- namespace library {
324
-
325
- // Initialize all instances
326
- void initialize_${configuration_name}(Manifest &manifest) {
327
- """
328
-
329
- self.configuration_instance = """${stub_begin}
330
- using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
331
- ${operation_name}>;
332
-
333
- manifest.append(new cutlass::library::${operation_wrapper}<
334
- Operation_${operation_name}
335
- >(
336
- "${operation_name}"
337
- ));
338
- ${stub_end}
339
- """
340
-
341
- self.configuration_epilogue = "}\n"
342
-
343
- self.epilogue_template = """
344
-
345
- ///////////////////////////////////////////////////////////////////////////////////////////////////
346
-
347
- } // namespace library
348
- } // namespace cutlass
349
-
350
- ///////////////////////////////////////////////////////////////////////////////////////////////////
351
-
352
- """
353
-
354
- def operation_is_3x(self, operation):
355
- """Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
356
- return hasattr(operation, 'is_3x') and operation.is_3x
357
-
358
- def __enter__(self):
359
- """
360
- Open the configuration_file, and write the "header" C++ code to it.
361
-
362
- The "header" consists of a comment (that this is generated code,
363
- so it should not be edited), and includes that are common
364
- to both the CUTLASS 2 and the CUTLASS 3 cases.
365
- """
366
- _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__')
367
- _LOGGER.debug('*** configuration_path (file to write): ' +
368
- str(self.configuration_path))
369
- _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
370
- self.configuration_file = open(self.configuration_path, "w")
371
-
372
- self.configuration_file.write(SubstituteTemplate(self.header_template, {
373
- 'configuration_name': self.configuration_name
374
- }))
375
- self.operations = []
376
- return self
377
-
378
- def emit(self, operation):
379
- """
380
- Write three pieces of C++ code to the configuration_file
381
- (that was opened by the __enter__ method above):
382
-
383
- 1. the header includes that are specific to the operation
384
- (CUTLASS 2 vs. CUTLASS 3);
385
-
386
- 2. the "operation instance" (a "using" declaration ending in "_base"); and
387
-
388
- 3. the "operation name" (declaration and definition of a derived class
389
- of the above operation instance).
390
-
391
- The "using" declaration turns a C++ class name, possibly namespace-qualified,
392
- possibly also with angle brackets, into a C-style, easily demangled identifier.
393
- """
394
- _LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit')
395
- _LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
396
- self.operations.append(operation)
397
-
398
- self.configuration_file.write(self.includes_emitter.emit(operation))
399
-
400
- stub_begin = ''
401
- stub_end = ''
402
- # It can be useful to stub (comment) out instantiations for testing.
403
- # In this case, one need only set is_stub to True.
404
- is_stub = False
405
- if is_stub:
406
- stub_begin = "// STUB for now\n#if 0"
407
- stub_end = '#endif // 0'
408
-
409
- self.configuration_file.write(Template(self.instance_template).substitute({
410
- 'configuration_name': self.configuration_name,
411
- 'operation_name': operation.procedural_name(),
412
- 'operation_instance': self.instance_emitter.emit(operation),
413
- 'stub_begin': stub_begin,
414
- 'stub_end': stub_end
415
- }))
416
-
417
- def __exit__(self, exception_type, exception_value, traceback):
418
- """
419
- Write the rest of the C++ code to the configuration_file, and close the file.
420
-
421
- The "rest of the C++ code" has the following components.
422
-
423
- 1. Configuration header: Open the namespace(s), and open the definition
424
- of the "initialize_${configuration_name}" registration function
425
- that registers the operation with the Manifest.
426
- ("Registration" helps turn C++ compile-time polymorphism
427
- (via template parameters) into a run-time choice of parameters.)
428
-
429
- 2. Configuration instance: In the body of the registration function,
430
- make a "using" declaration Operation_${operation_name} for the
431
- operation type (which uses operation_name as its template argument).
432
- Then, tell the manifest about the operation via a "manifest.append" call.
433
- The argument of the call is a new instance of
434
- "SomethingOperation<Operation_${operation_name}>"
435
- (replace Something with a specific name).
436
-
437
- 3. Configuration epilogue: Close the definition of the registration function.
438
-
439
- 4. Epilogue template: Close the namespace(s).
440
- """
441
-
442
- _LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__')
443
- _LOGGER.debug('*** configuration_path (file to write): ' +
444
- str(self.configuration_path))
445
- _LOGGER.debug('*** configuration_name: ' + self.configuration_name)
446
-
447
- self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
448
- 'configuration_name': self.configuration_name
449
- }))
450
-
451
- for operation in self.operations:
452
- stub_begin = ''
453
- stub_end = ''
454
- # It can be useful to stub (comment) out instantiations for testing.
455
- # In this case, one need only set is_stub to True.
456
- is_stub = False
457
- if is_stub:
458
- stub_begin = "// STUB for now\n#if 0"
459
- stub_end = "#endif // 0"
460
-
461
- kernel_name = 'ImplicitGemmConvolution'
462
- operation_wrapper = 'Conv3dOperation'
463
- if self.operation_is_3x(operation):
464
- kernel_name = 'ConvUniversalAdapter'
465
- operation_wrapper = 'ConvOperation3x'
466
-
467
- self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
468
- 'configuration_name': self.configuration_name,
469
- 'operation_name': operation.procedural_name(),
470
- 'kernel_name': kernel_name,
471
- 'operation_wrapper': operation_wrapper,
472
- 'stub_begin': stub_begin,
473
- 'stub_end': stub_end
474
- }))
475
-
476
- self.configuration_file.write(self.configuration_epilogue)
477
- self.configuration_file.write(self.epilogue_template)
478
- self.configuration_file.close()
479
-
480
-
481
- ###################################################################################################
482
- ###################################################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py DELETED
@@ -1,250 +0,0 @@
1
- #################################################################################################
2
- #
3
- # Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
- # SPDX-License-Identifier: BSD-3-Clause
5
- #
6
- # Redistribution and use in source and binary forms, with or without
7
- # modification, are permitted provided that the following conditions are met:
8
- #
9
- # 1. Redistributions of source code must retain the above copyright notice, this
10
- # list of conditions and the following disclaimer.
11
- #
12
- # 2. Redistributions in binary form must reproduce the above copyright notice,
13
- # this list of conditions and the following disclaimer in the documentation
14
- # and/or other materials provided with the distribution.
15
- #
16
- # 3. Neither the name of the copyright holder nor the names of its
17
- # contributors may be used to endorse or promote products derived from
18
- # this software without specific prior written permission.
19
- #
20
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
- # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
- # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
- # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
- #
31
- #################################################################################################
32
-
33
- """
34
- Utilities for emitting CUTLASS >= 3 convolution kernels
35
- """
36
-
37
- import enum
38
- import os.path
39
- import shutil
40
- import logging
41
- from string import Template
42
-
43
- try:
44
- import builtins
45
- if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
46
- raise ImportError("Disabling attempt to import cutlass_library")
47
- from cutlass_library.library import *
48
- except ImportError:
49
- from library import *
50
-
51
- _LOGGER = logging.getLogger(__name__)
52
-
53
- ###################################################################################################
54
- #
55
- # Emits single instances of a CUTLASS device-wide operator
56
- #
57
- ###################################################################################################
58
-
59
- class EmitConv3xInstance:
60
- def __init__(self):
61
- _LOGGER.debug("*** EmitConv3xInstance::__init__")
62
-
63
- # Define epilogue type first, so that the mainloop type
64
- # can use it with StageCountAutoCarveout.
65
- self.template = """
66
-
67
- // CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}"
68
- using ${operation_name}_epilogue =
69
- typename cutlass::epilogue::collective::CollectiveBuilder<
70
- ${arch},
71
- ${opcode_class_epi},
72
- ${mma_tile_shape}, // mma tile shape
73
- ${cluster_shape}, // cluster shape
74
- ${epi_tile_mn},
75
- ${element_accumulator},
76
- ${element_compute},
77
- ${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>,
78
- ${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>,
79
- ${epilogue_schedule}
80
- // , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute>
81
- >::CollectiveOp;
82
-
83
- using ${operation_name}_mainloop =
84
- typename cutlass::conv::collective::CollectiveBuilder<
85
- ${arch},
86
- ${opcode_class_main},
87
- ${conv_kind}, // kFprop, kDgrad, or kWgrad
88
- ${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>,
89
- ${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>,
90
- ${element_accumulator},
91
- ${mma_tile_shape}, // mma tile shape
92
- ${cluster_shape}, // cluster shape
93
- ${stages},
94
- ${kernel_schedule}
95
- >::CollectiveOp;
96
-
97
- using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>;
98
-
99
- // Unit tests call this "ConvKernel".
100
- // Conv operator ${operation_name}
101
- using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
102
- ${operation_name}_problem_shape,
103
- ${operation_name}_mainloop,
104
- ${operation_name}_epilogue,
105
- ${tile_scheduler}
106
- >;
107
- """
108
-
109
- def arch_number_to_type(self, arch: int) -> str:
110
- return f"cutlass::arch::Sm{arch}"
111
-
112
- def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
113
- mma_m = cta_m
114
- mma_n = cta_n
115
- mma_k = cta_k
116
-
117
- if operation.arch >= 100:
118
- # MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where
119
- # mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version.
120
- # If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated,
121
- # otherwise 1sm kernel is allocated.
122
- cta_m_per_mma_instruction = 1
123
- if "2sm" in operation.procedural_name() :
124
- cta_m_per_mma_instruction = 2
125
- elif "1sm" in operation.procedural_name() :
126
- cta_m_per_mma_instruction = 1
127
- elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 :
128
- cta_m_per_mma_instruction = 2
129
- mma_m = cta_m * cta_m_per_mma_instruction
130
-
131
- # For all three kinds of convolutions, the tile shape's K mode
132
- # differs from GEMM in that needs to be wrapped in a Shape.
133
- # For Wgrad convolutions specifically,
134
- # the N tile shape also needs to be wrapped in a Shape.
135
- m_template = 'cute::_${mma_m}'
136
- if operation.conv_kind == ConvKind.Wgrad:
137
- n_template = 'cute::Shape<cute::_${mma_n}>'
138
- else:
139
- n_template = 'cute::_${mma_n}'
140
- k_template = 'cute::Shape<cute::_${mma_k}>'
141
-
142
- mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
143
- values = {
144
- 'mma_m': mma_m,
145
- 'mma_n': mma_n,
146
- 'mma_k': mma_k
147
- }
148
- return Template(mma_tile_shape_template).substitute(values)
149
-
150
- def cluster_shape(self, operation) -> str:
151
- m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)'
152
- n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)'
153
- k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)'
154
- cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
155
- values = {
156
- 'cluster_shape_m': operation.tile_description.cluster_shape[0],
157
- 'cluster_shape_n': operation.tile_description.cluster_shape[1],
158
- 'cluster_shape_k': operation.tile_description.cluster_shape[2],
159
- }
160
- return Template(cluster_shape_template).substitute(values)
161
-
162
- def stage_count(self, operation) -> str:
163
- # stages == 0 tells builder to pick the number of stages automatically
164
- namespace_prefix = 'cutlass::conv::collective::'
165
- if operation.tile_description.stages > 0:
166
- return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>"
167
- else:
168
- return f"{namespace_prefix}StageCountAutoCarveout<sizeof(typename {operation.procedural_name()}_epilogue::SharedStorage)>"
169
-
170
- def emit(self, operation) -> str:
171
- _LOGGER.debug("*** EmitConv3xInstance::emit")
172
- _LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
173
-
174
- # Identify the operation as CUTLASS 3 by its is_3x field
175
- if (not hasattr(operation, 'is_3x')) or (not operation.is_3x):
176
- raise RuntimeError("operation must be a CUTLASS 3 operation")
177
-
178
- epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
179
- opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
180
- opcode_class_epi = opcode_class_main
181
-
182
- tile_shape = operation.tile_description.tile_shape
183
- cluster_m = operation.tile_description.cluster_shape[0]
184
- cluster_n = operation.tile_description.cluster_shape[1]
185
-
186
- cta_m, cta_n, cta_k = tile_shape
187
- # account for static/dynamic cluster shapes
188
- if operation.arch >= 100:
189
- cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m
190
- cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n
191
-
192
- warp_count = operation.tile_description.warp_count
193
- epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule]
194
-
195
- # KernelScheduleTag and TileSchedulerTag both hard-code the
196
- # namespace qualification of KernelScheduleAuto as
197
- # "cutlass::gemm::collective::" (unless the tag is 'void').
198
- #
199
- # For TileSchedulerTag, this namespace is fine, since CUTLASS 3
200
- # convolutions use the same tile schedulers (from the same
201
- # cutlass::gemm::collective namespace) as GEMMs.
202
- kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::')
203
- tile_scheduler = TileSchedulerTag[operation.tile_scheduler]
204
- opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
205
-
206
- values = {
207
- 'operation_name': operation.procedural_name(),
208
- 'conv_kind': ConvKindTag[operation.conv_kind],
209
- 'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
210
- 'element_a': DataTypeTag[operation.A.element],
211
- 'layout_a': LayoutTag[operation.A.layout],
212
- 'align_a': int(operation.A.alignment),
213
- 'element_b': DataTypeTag[operation.B.element],
214
- 'layout_b': LayoutTag[operation.B.layout],
215
- 'align_b': int(operation.B.alignment),
216
- 'element_c': DataTypeTag[operation.C.element],
217
- 'layout_c': LayoutTag[operation.C.layout],
218
- 'align_c': int(operation.C.alignment),
219
- 'element_d': DataTypeTag[operation.D.element],
220
- 'layout_d': LayoutTag[operation.D.layout],
221
- 'align_d': int(operation.D.alignment),
222
- 'element_accumulator': DataTypeTag[operation.accumulator_type()],
223
- 'opcode_class': opcode_class,
224
- 'arch': self.arch_number_to_type(operation.arch),
225
- 'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
226
- 'cluster_shape': self.cluster_shape(operation),
227
- 'opcode_class_epi': opcode_class_epi,
228
- 'opcode_class_main': opcode_class_main,
229
- 'epi_tile_mn': epi_tile_mn,
230
- 'stages': self.stage_count(operation),
231
- 'kernel_schedule': kernel_schedule,
232
- 'epilogue_schedule': epilogue_schedule,
233
- 'tile_scheduler': tile_scheduler,
234
- 'element_compute': DataTypeTag[operation.element_compute]
235
- }
236
- return Template(self.template).substitute(values)
237
-
238
- class EmitConv3xIncludes:
239
- def __init__(self):
240
- _LOGGER.debug("*** EmitConv3xIncludes::__init__")
241
- self.includes = ['conv_operation_3x.hpp',
242
- 'cutlass/conv/device/conv_universal_adapter.hpp',
243
- 'cutlass/conv/kernel/conv_universal.hpp',
244
- 'cutlass/conv/collective/collective_builder.hpp',
245
- 'cutlass/epilogue/collective/collective_builder.hpp']
246
-
247
- def emit(self, operation) -> str:
248
- _LOGGER.debug("*** EmitConv3xIncludes::emit")
249
- return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
250
- "\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"