kernels-bot committed on
Commit
28b1792
·
verified ·
1 Parent(s): 95681b6

Uploaded using `kernel-builder` (batch 5/6).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py +116 -0
  3. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py +134 -0
  4. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py +47 -0
  5. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py +258 -0
  6. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py +98 -0
  7. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py +329 -0
  8. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py +168 -0
  9. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py +33 -0
  10. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py +272 -0
  11. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py +194 -0
  12. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py +53 -0
  13. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py +91 -0
  14. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py +254 -0
  15. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py +324 -0
  16. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py +336 -0
  17. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py +294 -0
  18. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py +306 -0
  19. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py +277 -0
  20. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py +137 -0
  21. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py +42 -0
  22. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py +143 -0
  23. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py +120 -0
  24. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py +169 -0
  25. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py +64 -0
  26. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py +90 -0
  27. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py +217 -0
  28. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py +164 -0
  29. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py +53 -0
  30. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py +97 -0
  31. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py +59 -0
  32. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py +319 -0
  33. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py +46 -0
  34. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py +109 -0
  35. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py +2145 -0
  36. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py +509 -0
  37. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py +121 -0
  38. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py +140 -0
  39. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py +455 -0
  40. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py +35 -0
  41. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py +33 -0
  42. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py +126 -0
  43. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py +33 -0
  44. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py +267 -0
  45. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py +936 -0
  46. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py +56 -0
  47. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py +176 -0
  48. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py +98 -0
  49. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py +569 -0
  50. build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py +36 -0
.gitattributes CHANGED
@@ -61,3 +61,4 @@ build/torch212-cxx11-cu132-aarch64-linux/_deep_gemm_cuda_388adb9.abi3.so filter=
61
  build/torch211-cu128-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
62
  build/torch211-cu130-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
63
  build/torch212-cu130-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
61
  build/torch211-cu128-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
62
  build/torch211-cu130-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
63
  build/torch212-cu130-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
64
+ build/torch212-cu132-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Emitter for Sm100 Epilogue Visitor
35
+ """
36
+
37
+ from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
38
+ from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
39
+ from cutlass_cppgen.backend import GemmOperationUniversal
40
+ from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
41
+ from cutlass_cppgen.backend.evt.ir.node import TupleEmitter
42
+
43
+
44
class Sm100CollectiveEpilogue:
    """
    Emits the Sm100 `EpilogueDescriptor` declaration followed by the fusion callbacks declaration
    """

    def __init__(self, tile_description,
                 kernel_schedule,
                 epilogue_schedule,
                 element_accumulator,
                 element_d,
                 fusion_callbacks) -> None:
        """
        Store the tile shape, schedule and element types that `emit` formats

        The element type of C is taken from the "C" node of the DAG IR when such a node
        exists, and is `DataType.void` otherwise.
        """
        # Only the first value returned by the helper is kept
        self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(
            tile_description, tile_description.cluster_shape, kernel_schedule
        )
        self.element_accumulator = element_accumulator

        dag_ir = fusion_callbacks.dag_ir
        self.element_c = dag_ir.get_node_meta("C").element if dag_ir.has_node("C") else DataType.void

        self.element_d = element_d
        self.schedule = epilogue_schedule
        self.fusion_callbacks = fusion_callbacks
        self.opclass = tile_description.math_instruction.opcode_class

    @property
    def CtaTileMNK(self) -> str:
        """
        The threadblock shape, formatted as a `cute::Shape` of three static integers
        """
        tile_m = self.cta_tile_mnk[0]
        tile_n = self.cta_tile_mnk[1]
        tile_k = self.cta_tile_mnk[2]
        return f"cute::Shape<_{tile_m}, _{tile_n}, _{tile_k}>"

    @property
    def EpilogueTileType(self) -> str:
        """
        The epilogue tile type (always the automatic one)
        """
        return "cutlass::epilogue::collective::EpilogueTileAuto"

    @property
    def Schedule(self) -> str:
        """
        The tag that `EpilogueScheduleTag` maps the stored epilogue schedule to
        """
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """
        Return `(callback_name, declaration)`

        The declaration string holds the `EpilogueDescriptor` alias followed by the
        declaration produced by the fusion callbacks.
        """
        tuple_emitter = TupleEmitter("int64_t")
        dag_ir = self.fusion_callbacks.dag_ir
        stride_D_str = dag_ir.get_node_meta("D").underlying_impl.stride_mnl
        # Without a "C" node, the stride of D is used in the C position as well
        if dag_ir.has_node("C"):
            stride_C_str = dag_ir.get_node_meta("C").underlying_impl.stride_mnl
        else:
            stride_C_str = stride_D_str

        callback_decl, callback_name = self.fusion_callbacks.emit()
        epilogue_decl = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
{OpcodeClassTag[self.opclass]},
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}, {stride_C_str}, {stride_D_str},
false /* IsPerColScaleSupported */,
false /* IsBlockScaleSupported */
>;
{callback_decl}
"""
        return callback_name, epilogue_decl
101
+
102
class Sm100Emitter:
    """
    Emitter for the Sm100 epilogue visitor of a GEMM operation
    """

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """
        Build the fusion callbacks for `graph` and the collective epilogue that wraps them
        """
        callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)

        tile = operation.tile_description
        kernel_schedule = tile.kernel_schedule
        epilogue_schedule = tile.epilogue_schedule
        accumulator_type = tile.math_instruction.element_accumulator
        # The element type of D comes from the "D" node of the DAG IR
        output_type = callbacks.dag_ir.get_node_meta("D").element

        self.collective_epilogue = Sm100CollectiveEpilogue(
            tile_description=tile,
            kernel_schedule=kernel_schedule,
            epilogue_schedule=epilogue_schedule,
            element_accumulator=accumulator_type,
            element_d=output_type,
            fusion_callbacks=callbacks,
        )

    def emit(self):
        """
        Forward to the collective epilogue and return what it emits
        """
        emitted = self.collective_epilogue.emit()
        return emitted
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from pycute import product
34
+
35
+ from cutlass_library import DataTypeSize, DataTypeTag
36
+
37
+ from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl
38
+ import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
39
+
40
+ from cutlass_cppgen.backend.library import FloatRoundStyleTag
41
+
42
+
43
# For these node kinds the Sm100 names are plain aliases of the Sm90 implementations
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
53
+
54
+
55
class Sm100AuxLoadImpl(AuxLoadImpl):
    """
    Sm100 Aux Load node: an `Sm90AuxLoad` parameterized through an `Sm100AuxLoadDescriptor`
    """

    @property
    def descriptor(self) -> str:
        """
        Name of the descriptor type declared for this Aux Load node
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Build the `using` declaration of the descriptor type
        """
        return (
            f"\nusing {self.descriptor} = "
            f"cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<"
            f"EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
        )

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            # The descriptor is declared first because the node type refers to it
            self._type_decl = self.decl_descriptor()
            self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size from the element size, stages_c and epilogue_tile_mn

        Returns a 2-tuple whose second entry is the constant 128
        (presumably an alignment -- confirm against the caller).
        """
        total = DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn)
        return (total // 8, 128)
92
+
93
+
94
class Sm100AuxStoreImpl(AuxStoreImpl):
    """
    Sm100 Aux Store node: an `Sm90AuxStore` parameterized through an `Sm100AuxStoreDescriptor`
    """

    @property
    def descriptor(self) -> str:
        """
        Name of the descriptor type declared for this Aux Store node
        """
        return f"{self.name_camel}Descriptor"

    def decl_descriptor(self) -> str:
        """
        Build the `using` declaration of the descriptor type
        """
        declaration = f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
        return declaration

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            # The descriptor is declared first because the node type refers to it
            self._type_decl = self.decl_descriptor()
            self._type_decl += f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Get the shared memory size from the element size, stages_d and epilogue_tile_mn

        Returns a 2-tuple whose second entry is the constant 128
        (presumably an alignment -- confirm against the caller).
        """
        total = DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn)
        return (total // 8, 128)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Emitter for Sm80 Epilogue Visitor
35
+ """
36
+
37
+ from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
38
+ from cutlass_cppgen.backend import GemmOperationUniversal
39
+
40
+
41
class Sm80Emitter:
    """
    Emitter for the Sm80 epilogue visitor
    """

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """
        Build the fusion callbacks for `graph`; `operation` is accepted but not read here
        """
        self.fusion_callbacks = FusionCallbacks(graph, cc=80)

    def emit(self):
        """
        Return `(callback_name, callback_decl)`, i.e. the callbacks' result with its order swapped
        """
        decl, name = self.fusion_callbacks.emit()
        return name, decl
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_library import DataTypeSize, DataTypeTag
34
+
35
+ from cutlass_cppgen.backend.evt.ir import (
36
+ # Load Node
37
+ AccumulatorImpl,
38
+ AuxLoadImpl,
39
+ ColumnBroadcastImpl,
40
+ LoadNode,
41
+ LoadSrcImpl,
42
+ RowBroadcastImpl,
43
+ ScalarBroadcastImpl,
44
+ # Compute Node
45
+ ComputeImpl,
46
+ # Store Node
47
+ AuxStoreImpl,
48
+ ColumnReductionImpl,
49
+ RowReductionImpl,
50
+ ScalarReductionImpl
51
+ )
52
+
53
+ from cutlass_cppgen.backend.library import (
54
+ FloatRoundStyleTag,
55
+ FunctionalOp,
56
+ op_tag,
57
+ )
58
+
59
+
60
class Sm80AccumulatorImpl(AccumulatorImpl):
    """
    Sm80 accumulator node, declared as `VisitorAccFetch`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"
        return self._type_decl
72
+
73
+
74
class Sm80AuxLoadImpl(AuxLoadImpl):
    """
    Sm80 Aux Load node, declared as `VisitorAuxLoad`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
>;
"""
        return self._type_decl
90
+
91
+
92
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
    # Inherits everything from Sm80AuxLoadImpl; nothing is overridden here
    pass
94
+
95
+
96
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
    """
    Sm80 scalar broadcast node, declared as `VisitorScalarBroadcast`
    """

    def __init__(self, node: LoadNode) -> None:
        """
        Initialize the base implementation, then fix the broadcast count and reduction function
        """
        super().__init__(node)
        # Both values are formatted into the declaration built by `type_decl`
        self.broadcast_count = 1
        self.reduction_fn = FunctionalOp.Multiplies

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl
116
+
117
+
118
class Sm80RowBroadcastImpl(RowBroadcastImpl):
    """
    Sm80 row broadcast node, declared as `VisitorRowBroadcast`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
        return self._type_decl
135
+
136
+
137
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
    """
    Sm80 column broadcast node, declared as `VisitorColBroadcast`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
OutputTileThreadMap, {DataTypeTag[self.element]},
{self.stride_mnl}
>;
"""
        return self._type_decl
154
+
155
+
156
class Sm80ComputeImpl(ComputeImpl):
    """
    Sm80 compute node, declared as `VisitorCompute`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl
173
+
174
+
175
class Sm80AuxStoreImpl(AuxStoreImpl):
    """
    Sm80 Aux Store node, declared as `VisitorAuxStore`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
        return self._type_decl
192
+
193
+
194
class Sm80StoreDImpl(Sm80AuxStoreImpl):
    # Inherits everything from Sm80AuxStoreImpl; nothing is overridden here
    pass
196
+
197
+
198
class Sm80ColumnReductionImpl(ColumnReductionImpl):
    """
    Sm80 column reduction node, declared as `VisitorColReduction`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
        return self._type_decl
217
+
218
+
219
class Sm80RowReductionImpl(RowReductionImpl):
    """
    Sm80 row reduction node, declared as `VisitorRowReduction`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
        return self._type_decl
238
+
239
+
240
class Sm80ScalarReductionImpl(ScalarReductionImpl):
    """
    Sm80 scalar reduction node, declared as `VisitorScalarReduction`
    """

    @property
    def type_decl(self):
        """
        Return the string defining the type, building and caching it on first access
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
OutputTileThreadMap, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
        return self._type_decl
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Emitter for Sm90 Epilogue Visitor
35
+ """
36
+
37
+ from cutlass_library import DataTypeTag, EpilogueScheduleTag
38
+ from cutlass_cppgen.backend import GemmOperationUniversal
39
+ from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
40
+
41
+
42
class CollectiveEpilogue:
    """
    Renders the ``EpilogueDescriptor`` alias followed by the fusion-callback
    declarations produced by ``fusion_callbacks.emit()``.
    """

    def __init__(self, tile_description,
                 schedule,
                 element_c,
                 element_d,
                 fusion_callbacks) -> None:
        """
        :param tile_description: object whose ``threadblock_shape`` is indexable at 0, 1 and 2
        :param schedule: key into ``EpilogueScheduleTag``
        :param element_c: key into ``DataTypeTag`` for the C operand
        :param element_d: key into ``DataTypeTag`` for the D operand
        :param fusion_callbacks: object whose ``emit()`` returns ``(declaration, name)``
        """
        self.fusion_callbacks = fusion_callbacks
        self.schedule = schedule
        self.element_d = element_d
        self.element_c = element_c
        self.cta_tile_mnk = tile_description.threadblock_shape

    @property
    def CtaTileMNK(self) -> str:
        """
        The threadblock shape spelled as a ``cute::Shape`` of static integers
        """
        tile_m = self.cta_tile_mnk[0]
        tile_n = self.cta_tile_mnk[1]
        tile_k = self.cta_tile_mnk[2]
        return f"cute::Shape<_{tile_m}, _{tile_n}, _{tile_k}>"

    @property
    def EpilogueTileType(self) -> str:
        """
        The epilogue tile type (always the automatic tile selector)
        """
        tile_type = "cutlass::epilogue::collective::EpilogueTileAuto"
        return tile_type

    @property
    def Schedule(self) -> str:
        """
        The C++ tag looked up from ``EpilogueScheduleTag`` for ``self.schedule``
        """
        return EpilogueScheduleTag[self.schedule]

    def emit(self):
        """
        Render the epilogue declarations.

        :return: ``(callback_name, declaration)`` where ``declaration`` holds the
            ``EpilogueDescriptor`` alias followed by the callback declaration
        """
        callback_decl, callback_name = self.fusion_callbacks.emit()
        declaration = f"""
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
{self.CtaTileMNK}, {self.EpilogueTileType},
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
{self.Schedule}
>;
{callback_decl}
"""
        return callback_name, declaration
83
+
84
+
85
class Sm90Emitter:
    """Emitter that wires an operation and an EVT graph into a ``CollectiveEpilogue``."""

    def __init__(self, operation: GemmOperationUniversal, graph) -> None:
        """
        :param operation: GEMM operation supplying the tile description and the C element type
        :param graph: EVT graph handed to ``FusionCallbacks``
        """
        callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
        tile = operation.tile_description
        # Both C and D element types are taken from ``operation.C``.
        element = operation.C.element

        self.collective_epilogue = CollectiveEpilogue(
            tile_description=tile,
            schedule=tile.epilogue_schedule,
            element_c=element,
            element_d=element,
            fusion_callbacks=callbacks,
        )

    def emit(self):
        """
        Delegate to the collective epilogue and return its result unchanged
        """
        return self.collective_epilogue.emit()
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from pycute import product
34
+
35
+ from cutlass_library import DataTypeSize, DataTypeTag
36
+ from cutlass_cppgen.backend.evt.ir import (
37
+ # Load Node
38
+ AccumulatorImpl,
39
+ AuxLoadImpl,
40
+ ColumnBroadcastImpl,
41
+ LoadNode,
42
+ LoadSrcImpl,
43
+ RowBroadcastImpl,
44
+ ScalarBroadcastImpl,
45
+ # Compute Node
46
+ ComputeImpl,
47
+ ComputeNode,
48
+ # Store Node
49
+ AuxStoreImpl,
50
+ ColumnReductionImpl,
51
+ RowReductionImpl,
52
+ ScalarReductionImpl,
53
+ StoreNode,
54
+ StoreDImpl,
55
+ )
56
+ from cutlass_cppgen.backend.library import (
57
+ FloatRoundStyleTag,
58
+ FunctionalOp,
59
+ op_tag,
60
+ )
61
+
62
+
63
class Sm90AccumulatorImpl(AccumulatorImpl):
    """Accumulator fetch rendered as ``cutlass::epilogue::fusion::Sm90AccFetch``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``
        """
        if self._type_decl is None:
            self._type_decl = f"\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"
        return self._type_decl
75
+
76
+
77
class Sm90LoadSrcImpl(LoadSrcImpl):
    """Source load rendered as ``cutlass::epilogue::fusion::Sm90SrcFetch``."""

    @property
    def type_decl(self):
        """
        C++ declarations for this node, rendered once and cached in ``_type_decl``.

        Besides the node alias, the text also declares ``ElementC`` and ``StrideC``.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using ElementC = {DataTypeTag[self.element]};
using StrideC = {self.stride_mnl};
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
"""
        return self._type_decl
93
+
94
+
95
class Sm90AuxLoadImpl(AuxLoadImpl):
    """Auxiliary load rendered as ``cutlass::epilogue::fusion::Sm90AuxLoad``."""

    @property
    def descriptor(self) -> str:
        """
        Name of the C++ descriptor alias that accompanies this aux load
        """
        alias = f"{self.name_camel}Descriptor"
        return alias

    def decl_descriptor(self) -> str:
        """
        Render the ``using`` declaration of the ``AuxLoadDescriptor`` alias
        """
        return (
            f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<"
            f"EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
        )

    @property
    def type_decl(self):
        """
        C++ declarations for this node: the descriptor alias followed by the
        ``Sm90AuxLoad`` alias. Rendered once and cached in ``_type_decl``.
        """
        if self._type_decl is None:
            descriptor_decl = self.decl_descriptor()
            node_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
>;
"""
            self._type_decl = descriptor_decl + node_decl
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Shared memory requirement of this aux load.

        Only ``epilogue_tile_mn`` and ``stages_c`` are used; the remaining
        parameters are accepted but ignored here.

        :return: ``(size, 128)`` where ``size`` is
            ``DataTypeSize[element] * stages_c * product(epilogue_tile_mn) // 8``
        """
        # NOTE(review): the constant 128 is presumably an alignment requirement — confirm with callers.
        size_in_bits = DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn)
        return (size_in_bits // 8, 128)
132
+
133
+
134
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
    """Scalar broadcast rendered as ``cutlass::epilogue::fusion::Sm90ScalarBroadcast``."""

    def __init__(self, node: LoadNode) -> None:
        """
        :param node: load node forwarded to the base implementation
        """
        super().__init__(node)
        # Template arguments emitted by ``type_decl``: one broadcast, combined with ``Multiplies``.
        self.reduction_fn = FunctionalOp.Multiplies
        self.broadcast_count = 1

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
>;
"""
        return self._type_decl
154
+
155
+
156
class Sm90RowBroadcastImpl(RowBroadcastImpl):
    """Row broadcast rendered as ``cutlass::epilogue::fusion::Sm90RowBroadcast``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
{self.stride_mnl}
>;
"""
        return self._type_decl
172
+
173
+
174
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
    """Column broadcast rendered as ``cutlass::epilogue::fusion::Sm90ColBroadcast``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
{self.stride_mnl}
>;
"""
        return self._type_decl
191
+
192
+
193
class Sm90ComputeImpl(ComputeImpl):
    """Compute node rendered as ``cutlass::epilogue::fusion::Sm90Compute``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}
>;
"""
        return self._type_decl
210
+
211
+
212
class Sm90AuxStoreImpl(AuxStoreImpl):
    """Auxiliary store rendered as ``cutlass::epilogue::fusion::Sm90AuxStore``."""

    @property
    def descriptor(self) -> str:
        """
        Name of the C++ descriptor alias that accompanies this aux store
        """
        alias = f"{self.name_camel}Descriptor"
        return alias

    def decl_descriptor(self) -> str:
        """
        Render the ``using`` declaration of the ``AuxStoreDescriptor`` alias
        """
        descriptor_decl = f"""
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
>;
"""
        return descriptor_decl

    @property
    def type_decl(self):
        """
        C++ declarations for this node: the descriptor alias followed by the
        ``Sm90AuxStore`` alias. Rendered once and cached in ``_type_decl``.
        """
        if self._type_decl is None:
            descriptor_decl = self.decl_descriptor()
            node_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
typename {self.descriptor}::CopyOpR2S
>;
"""
            self._type_decl = descriptor_decl + node_decl
        return self._type_decl

    def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
        """
        Shared memory requirement of this aux store.

        Only ``epilogue_tile_mn`` and ``stages_d`` are used; the remaining
        parameters are accepted but ignored here.

        :return: ``(size, 128)`` where ``size`` is
            ``DataTypeSize[element] * stages_d * product(epilogue_tile_mn) // 8``
        """
        # NOTE(review): the constant 128 is presumably an alignment requirement — confirm with callers.
        size_in_bits = DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn)
        return (size_in_bits // 8, 128)
253
+
254
+
255
class Sm90StoreDImpl(StoreDImpl):
    """Store-D node; it only contributes the ``ElementD`` and ``StrideD`` aliases."""

    @property
    def type_decl(self):
        """
        C++ declarations of ``ElementD`` and ``StrideD``.

        Unlike the sibling nodes, the text is rebuilt on every access and is not cached.
        """
        declaration = f"""
using ElementD = {DataTypeTag[self.element]};
using StrideD = {self.stride_mnl};
"""
        return declaration
266
+
267
+
268
class Sm90ColumnReductionImpl(ColumnReductionImpl):
    """Column reduction rendered as ``cutlass::epilogue::fusion::Sm90ColReduction``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``.

        ``reg_reduce_fn`` is emitted for both of the first two template arguments.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
        return self._type_decl
287
+
288
+
289
class Sm90RowReductionImpl(RowReductionImpl):
    """Row reduction rendered as ``cutlass::epilogue::fusion::Sm90RowReduction``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``.

        ``reg_reduce_fn`` is emitted for both of the first two template arguments.
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
{self.stride_mnl}
>;
"""
        return self._type_decl
309
+
310
+
311
class Sm90ScalarReductionImpl(ScalarReductionImpl):
    """Scalar reduction rendered as ``cutlass::epilogue::fusion::Sm90ScalarReduction``."""

    @property
    def type_decl(self):
        """
        C++ ``using`` declaration for this node, rendered once and cached in ``_type_decl``
        """
        if self._type_decl is None:
            self._type_decl = f"""
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
{DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
>;
"""
        return self._type_decl
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Epilogue Visitor interface for compiling, and running visitor-based epilogue.
35
+ """
36
+
37
+ import ctypes
38
+
39
+ from cutlass_cppgen.utils.lazy_import import lazy_import
40
+ cuda = lazy_import("cuda.cuda")
41
+ from cutlass_library import DataType
42
+ import numpy as np
43
+
44
+ from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
45
+ import cutlass_cppgen.backend.evt.backend
46
+ from cutlass_cppgen.backend.frontend import TensorFrontend
47
+ from cutlass_cppgen.utils.datatypes import is_numpy_tensor
48
+ from cutlass_cppgen.backend.evt.passes.util import cc_map
49
+
50
+
51
class EpilogueFunctorVisitor(EpilogueFunctorBase):
    """
    Apply an epilogue functor described by the epilogue EVT

    :param cc: compute capability
    :param visitor: user-provided visitor frontend (already traced, so that it
        exposes ``dag_ir``, ``epilogue_thread_type``, ``return_names`` and
        ``reduction_names``)
    :param element_compute: element type stored as ``element_epilogue``
    """
    def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
        # The emitter class is selected by the EVT implementation that ``cc`` maps to
        evt_cc = cc_map[cc]
        self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{evt_cc}Emitter")

        # Visitor and its graph
        self.visitor = visitor
        self.graph = visitor.dag_ir

        # Data types
        self.element_epilogue = element_compute
        self.element_output = self.graph.get_node_meta('D').underlying_impl.element

        # Names below are captured by the nested ``_Arguments`` class
        epilogue_thread_type = self.visitor.epilogue_thread_type
        if evt_cc in [90, 100]:
            self.arg_c_type = self.visitor.arg_c_type
            self.arg_d_type = self.visitor.arg_d_type
        output_names = self.visitor.return_names
        reduction_names = self.visitor.reduction_names

        # Epilogue stages specialized for sm80 kernel
        # NOTE(review): this compares the raw ``cc`` while the emitter above uses
        # ``cc_map[cc]``; confirm whether other capabilities mapped to 80 should be included.
        if cc == 80 and hasattr(self.visitor, "epilogue_stages"):
            self.epilogue_stages = self.visitor.epilogue_stages
            assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"

        class _Arguments(ctypes.Structure):
            """
            Epilogue argument structure holding the thread-level ``output_op``.

            Concept: this class is the ``epilogue`` member of the enclosing
            arguments structure, e.g.

                class _EpilogueArguments(ctypes.Structure):
                    _fields_ = [
                        ("epilogue", _Arguments),  <- this class
                        ("ptr_C", ctypes.c_void_p),
                        ("stride_C", StrideBatched_),
                        ("ptr_D", ctypes.c_void_p),
                        ("stride_D", StrideBatched_)
                    ]
            """
            _fields_ = [
                ("output_op", epilogue_thread_type)
            ]

            def __init__(self, kwargs: dict) -> None:
                """
                :param kwargs: mapping of tensor name to user-provided tensor (or float constant)
                """
                # Convert every user input into a device pointer (or scalar value).
                # A tensor counts as an output when it is returned and is not a reduction.
                ptr_kwargs = {
                    key: self.get_tensor_ptr(
                        key, kwargs, key in output_names and key not in reduction_names
                    )
                    for key in kwargs
                }
                # Initialize the thread arguments
                self.output_op = epilogue_thread_type(ptr_kwargs)

            def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
                """
                Helper function for extracting device pointer

                :return: ``0`` for the skipped "C"/"D" tensors, the value itself for a
                    float constant, otherwise the device pointer as ``int``
                :raises ValueError: if ``tensor_name`` is missing from ``kwargs``
                """
                # Skip the special tensors
                # NOTE(review): raw ``cc`` is tested here rather than ``cc_map[cc]`` —
                # confirm the intended behavior for capabilities that map to 90/100.
                if cc in [90, 100] and tensor_name in ["C", "D"]:
                    return 0
                if tensor_name not in kwargs:
                    raise ValueError(f"Tensor {tensor_name} is not provided.")
                tensor = kwargs[tensor_name]

                # For float scalar constant, directly return the value
                if isinstance(tensor, float):
                    return tensor

                # The tensor frontend returns a device buffer for np.ndarray
                # and device ptr for other frontends
                buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
                if not is_numpy_tensor(tensor):
                    return int(buffer_or_ptr)

                # Remember the buffer and the host tensor for later synchronization
                setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
                setattr(self, f"{tensor_name}_host", tensor)
                return int(buffer_or_ptr.ptr)

            def sync(self):
                """
                Synchronize the results from device to host

                :raises RuntimeError: if the device-to-host copy reports an error
                """
                for name in output_names:
                    # Only tensors recorded by ``get_tensor_ptr`` have a host copy to refresh
                    if not hasattr(self, f"{name}_host"):
                        continue
                    host_tensor = getattr(self, f"{name}_host")
                    tensor_ptr = getattr(self, f"{name}_buffer").ptr
                    (err,) = cuda.cuMemcpyDtoH(
                        host_tensor,
                        tensor_ptr,
                        host_tensor.size * host_tensor.itemsize,
                    )
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError(f"CUDA Error {str(err)}")

        self.epilogue_type = _Arguments

    def emit(self, operation):
        """
        Emit the C++ code by instantiating the selected emitter for ``operation``
        """
        return self.emit_cls(operation, self.graph).emit()

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size, as reported by the visitor for ``tile_description``
        """
        return self.visitor.get_smem_size(tile_description)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Base class for Python EVT Frontend
35
+ """
36
+
37
+ from typing import Union
38
+
39
+ from cutlass_library import DataType
40
+ from cutlass_cppgen.backend.evt.ir import (
41
+ ComputeNode,
42
+ DAGIR,
43
+ LayoutNode,
44
+ LoadNode,
45
+ StoreNode,
46
+ )
47
+ from cutlass_cppgen.backend.evt.passes import (
48
+ EVTGraphDrawer,
49
+ EVTPassManager,
50
+ GetSmemSize,
51
+ PassDAG2Tree,
52
+ PassGetArgumentType,
53
+ PassGetImpl,
54
+ PassFixElementD,
55
+ PassLayoutManipulateElimination,
56
+ PassPreprocessRed,
57
+ PassShapeTypePropagation,
58
+ )
59
+ from cutlass_cppgen.backend.evt.passes.util import cc_map
60
+ from cutlass_cppgen.backend.utils import device_cc
61
+ from cutlass_cppgen.epilogue.evt_ops import permute, reshape
62
+ from cutlass_cppgen.utils.datatypes import library_type
63
+
64
+
65
class EVTFrontendBase:
    """
    Base class for Python EVT frontends.

    A subclass implements ``parse`` to populate ``self.dag_ir`` through the
    ``add_*`` helpers; ``trace`` then runs the pass pipeline over the DAG IR.
    """

    # Layout manipulation functions available to frontends, keyed by name
    layout_fns = {
        "permute": permute,
        "reshape": reshape
    }

    def __init__(self, cc, element_compute=DataType.f32, additional_passes=None, **kwargs) -> None:
        """
        :param cc: compute capability
        :param element_compute: compute element type, converted with ``library_type``
        :param additional_passes: extra passes appended after the default pipeline
            (``None`` or an empty iterable for none)
        :param kwargs: accepted for subclasses; unused here
        """
        self.cc = cc
        self.element_compute = library_type(element_compute)
        self.dag_ir = DAGIR(self.cc, self.element_compute)
        # Counters used to generate unique node names
        self.compute_cnt = 0
        self.layout_cnt = 0
        self.imm_cnt = 0

        default_passes = [
            PassPreprocessRed,
            PassGetArgumentType,
            PassShapeTypePropagation,
            PassLayoutManipulateElimination,
            PassGetImpl,
            PassDAG2Tree,
            PassFixElementD
        ]
        # ``None`` replaces the former mutable ``[]`` default, which was shared across calls
        extra_passes = [] if additional_passes is None else list(additional_passes)
        self.pass_manager = EVTPassManager(self.dag_ir, default_passes + extra_passes)

        if self.cc == 80:
            self._epilogue_stages = 1
        else:
            self._epilogue_stages = None

    @property
    def epilogue_stages(self):
        """Number of epilogue stages (``1`` by default when ``cc == 80``, otherwise ``None``)."""
        return self._epilogue_stages

    @epilogue_stages.setter
    def epilogue_stages(self, stages):
        self._epilogue_stages = stages

    def parse(self, *args, **kwargs):
        """
        Populate the DAG IR from the user input. Must be overridden by subclasses.

        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError("The 'parse' function must be overloaded in frontend class")

    def trace(self, *args, **kwargs):
        """
        Parse the input, validate the DAG IR, run the passes and record the
        resulting argument types on the frontend.

        :raises RuntimeError: if ``cc >= 90`` and node "D" has users
        """
        # Parse the input
        self.parse(*args, **kwargs)

        # Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
        if self.cc >= 90 and self.dag_ir.out_degree("D") != 0:
            raise RuntimeError(
                f"On SM90 or higher, D is expected to be a output node with 0 users to "
                f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")

        # Run the passes
        self.pass_manager()
        # Set the epilogue type
        self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
        if cc_map[self.cc] in [90, 100]:
            self.arg_c_type = self.dag_ir.arg_c_type
            self.arg_d_type = self.dag_ir.arg_d_type
        self.reduction_names = self.dag_ir.reduction_names

    #
    # Helper functions for DAG IR manipulation
    #

    def add_node(self, node):
        """Add ``node`` to the DAG IR."""
        self.dag_ir.add_node(node)

    def add_edge(self, src, tgt, weight=0):
        """Add an edge from ``src`` to ``tgt`` with the given ``weight`` to the DAG IR."""
        self.dag_ir.add_edge(src, tgt, weight=weight)

    def set_tensor(self, node_name, example):
        """
        Add an example tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.tensor = {"tensor": example}

    def set_store_tensor(self, node_name, example):
        """
        Add an example store tensor to node {node_name} in the DAG IR
        """
        meta = self.dag_ir.get_node_meta(node_name)
        meta.store_tensor = {"tensor": example}

    def mark_output(self, node_name):
        """
        Mark a store node as output

        :raises ValueError: if the node is not a ``StoreNode``
        """
        meta = self.dag_ir.get_node_meta(node_name)
        if not isinstance(meta, StoreNode):
            raise ValueError(
                f"Only StoreNodes can be marked as output. "
                f"Got {type(meta).__name__}: {node_name}")
        meta.is_output = True

    # Add node with specific type

    def add_load_node(self, name, example):
        """
        Add a Load node to DAG IR
        :param name: name of the loaded variable
        :type name: str
        :param example: example input
        :type example: np.ndarray|torch.Tensor|cupy.ndarray|float
        :raises ValueError: if ``name`` or ``example`` is ``None``, or if the
            example for 'accum' is not rank-2 or rank-3
        """
        if name is None:
            raise ValueError("Name is not provided.")
        if example is None:
            raise ValueError(f"Example input for {name} is not provided.")
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": example}
        # Special logics for accumulator
        if name == "accum":
            if load_node.tensor.rank == 2:
                # Prepend a size-1 leading mode so the accumulator is rank-3
                new_shape = tuple([1, ] + list(load_node.tensor.shape))
                load_node.tensor.broadcast(new_shape)
            elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
                raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
        self.add_node(load_node)

    def add_imm(self, value: Union[float,int]):
        """
        Add an immediate scalar value to DAG IR
        :param value: the value of the immediate scalar
        :type value: float
        :return: the generated name of the immediate node
        :raises ValueError: if ``value`` cannot be converted to float
        """
        try:
            value = float(value)
        except Exception as err:
            # ``Exception`` instead of a bare ``except`` lets KeyboardInterrupt and
            # SystemExit propagate; the original failure is kept as the cause.
            raise ValueError(f"{type(value).__name__} cannot be converted to float.") from err

        # The counter suffix keeps names unique when the same constant appears twice
        name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
        self.imm_cnt += 1
        load_node = LoadNode(name)
        load_node.tensor = {"tensor": value, "is_constant": True}
        self.add_node(load_node)
        return name

    def add_compute_node(self, op, name=None):
        """
        Add a compute node.
        :param op: the computation op
        :param name: the node name (optional)
        :type name: str
        :return: the name of the compute node
        """
        if name is None:
            name = f"compute_{self.compute_cnt}"
            self.compute_cnt += 1
        compute_node = ComputeNode(
            name=name, fn=op,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(compute_node)
        return compute_node.name

    def add_layout_node(self, op, kwargs, name=None):
        """
        Add a layout node.
        :param op: the layout op
        :type op: evt_ops
        :param kwargs: keyword arguments forwarded to the layout node
        :param name: the node name (optional)
        :type name: str
        :return: the name of the layout node
        """
        if name is None:
            name = f"layout_{self.layout_cnt}"
            self.layout_cnt += 1
        layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
        self.add_node(layout_node)
        return layout_node.name

    def add_store_node(self, name):
        """Add a Store node called ``name`` to the DAG IR."""
        store_node = StoreNode(name)
        self.add_node(store_node)

    #
    # Visualize the DAG IR
    #

    def visualize(self, name="dag_ir"):
        """
        Visualize the dag ir with svg file
        :param name: the name of the graph
        :raises RuntimeError: if drawing or writing any graph fails
        """
        drawer = EVTGraphDrawer(self.dag_ir, name)
        try:
            # A distinct loop variable avoids shadowing the ``name`` parameter
            for graph_name, graph in drawer.get_dot_graph():
                graph.write_svg(f"./{graph_name}.svg")
        except Exception as err:
            # Keep the original message, but chain the real failure so that
            # errors unrelated to a missing 'dot' binary remain diagnosable.
            raise RuntimeError(
                "'dot' is not found in path. GraphDrawer is disabled. "
                "Please install it with 'sudo apt-get install graphviz'."
            ) from err

    #
    # Get shared memory size
    #

    def get_smem_size(self, tile_description):
        """
        Get the shared memory size of the epilogue
        """
        smem_size = GetSmemSize(self.dag_ir)(tile_description)
        return smem_size
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Python AST frontend that parses input into DAG IR
35
+ """
36
+
37
+ import ast
38
+ import inspect
39
+ import textwrap
40
+
41
+ from cutlass_library import DataType
42
+
43
+ import cutlass_cppgen
44
+ from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
45
+ from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
46
+ from cutlass_cppgen.backend.library import FunctionalOp
47
+
48
+
49
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
    """
    Frontend that builds the DAG IR by walking the Python AST of the
    instance's ``__call__`` method.

    Each ``visit_*`` method returns the name of the graph node created for
    (or referenced by) the visited AST node, so parent visitors can connect
    edges by name.
    """

    def __init__(self, cc, element_compute=DataType.f32, **kwargs):
        super().__init__(cc, element_compute, **kwargs)
        # Flags
        # If this state is True, visit_Constant returns values without creating imm node
        self.no_imm = False
        # True only while the value of a ``return`` statement is being visited
        self.visiting_return = False

    def parse(self, example_inputs):
        """
        Parse the source of ``self.__call__`` and populate the graph.

        :param example_inputs: container indexed by argument / return name
            that provides an example tensor for each of them
        """
        self.example_inputs = example_inputs
        # The method source is indented inside its class; dedent so that it
        # parses as a standalone module.
        self.source = textwrap.dedent(inspect.getsource(self.__call__))
        self.ast = ast.parse(self.source)
        self.visit(self.ast)

    #
    # Helper functions
    #
    @staticmethod
    def ast_op_to_bindings(op):
        """
        Map an AST operator class or a function name to its op binding.

        :param op: an ``ast`` operator type (e.g. ``ast.Add``) or the name
            of a supported function (e.g. ``"relu"``)
        :return: the binding; ``"sum"`` and ``"max"`` map to a 2-tuple
        :raises KeyError: if ``op`` is not supported
        """
        mapping = {
            ast.Add: FunctionalOp.Plus,
            ast.Sub: FunctionalOp.Minus,
            ast.Mult: FunctionalOp.Multiplies,
            ast.Div: FunctionalOp.Divides,
            "maximum": FunctionalOp.Maximum,
            "minimum": FunctionalOp.Minimum,
            "identity": identity.binding_type,
            "relu": relu.binding_type,
            "tanh": tanh.binding_type,
            "sigmoid": sigmoid.binding_type,
            "silu": silu.binding_type,
            "hardswish": hardswish.binding_type,
            "gelu": gelu.binding_type,
            "multiply_add": FunctionalOp.MultiplyAdd,
            "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
            "max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
            "exp": FunctionalOp.Exp
        }
        return mapping[op]

    #
    # Visiting different node types
    #

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Register a load node per positional argument, then visit the body."""
        # Visit args and register load nodes
        for arg in node.args.args:
            self.visit(arg)
        for expr in node.body:
            self.visit(expr)

    def visit_arg(self, node: ast.arg):
        """
        Add a load node for a function argument.

        :raises RuntimeError: if no example input can be looked up for it
        """
        # Name of the argument
        name = node.arg
        try:
            example_tensor = self.example_inputs[name]
        except Exception as err:
            # Any lookup failure means the example input is unavailable;
            # keep the original error attached for debugging.
            raise RuntimeError(f"Example input for {name} is not provided.") from err

        self.add_load_node(name, example_tensor)

    def visit_Name(self, node: ast.Name):
        """Return the identifier; it doubles as the graph node name."""
        return node.id

    def visit_Constant(self, node: ast.Constant):
        """
        Return the raw constant when ``no_imm`` is set, otherwise create an
        immediate node through ``add_imm`` and return its name.
        """
        if self.no_imm:
            return node.value
        return self.add_imm(node.value)

    def visit_Tuple(self, node: ast.Tuple):
        """Visit every element and return the results as a tuple."""
        return tuple(self.visit(elt) for elt in node.elts)

    def visit_keyword(self, node: ast.keyword):
        """Return a single-entry dict ``{keyword name: visited value}``."""
        return {node.arg: self.visit(node.value)}

    def visit_BinOp(self, node: ast.BinOp):
        """
        Add a compute node for a binary operation and connect its operands.

        :raises SyntaxError: if the operation appears directly in a ``return``
        """
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = self.ast_op_to_bindings(type(node.op))
        name = self.add_compute_node(op)

        # Add edges
        # The edge weights are used to sort the input args
        self.add_edge(lhs, name, weight=0)
        self.add_edge(rhs, name, weight=1)
        return name

    def visit_Assign(self, node: ast.Assign):
        """
        Add a store node named after the first assignment target and connect
        the assigned value to it.
        """
        target = self.visit(node.targets[0])
        value = self.visit(node.value)
        # Create the assign node
        self.add_store_node(target)

        # Add edges
        self.add_edge(value, target)
        return target

    def visit_Call(self, node: ast.Call):
        """
        Add a layout or compute node for a function call.

        Functions listed in ``self.layout_fns`` become layout nodes that
        receive the call's keyword arguments; any other function name is
        resolved through ``ast_op_to_bindings`` into a compute node.

        :raises SyntaxError: if the call appears directly in a ``return``
        """
        if self.visiting_return:
            raise SyntaxError("Return value cannot be an expression")
        func = self.visit(node.func)
        args = [self.visit(arg) for arg in node.args]

        if func in self.layout_fns:
            # Parse kwargs
            # By default, visiting imm automatically creates a load node
            # However, in function call, keyword args are used to set
            # specific function attributes such as indices for permute
            # So no_imm is set to True temporarily
            self.no_imm = True
            try:
                kwargs = {}
                for kw in node.keywords:
                    kwargs.update(self.visit(kw))
            finally:
                # Restore the flag even when visiting a keyword fails, so a
                # later visit does not silently skip imm-node creation.
                self.no_imm = False
            op = self.layout_fns[func]
            name = self.add_layout_node(op, kwargs)
        else:
            op = self.ast_op_to_bindings(func)
            name = self.add_compute_node(op)

        # Add edges
        for idx, arg in enumerate(args):
            self.add_edge(arg, name, weight=idx)
        return name

    def visit_Return(self, node: ast.Return):
        """
        Record the returned names and mark each of them as an output.

        :raises RuntimeError: if no example tensor can be looked up for a
            returned name
        """
        self.visiting_return = True
        try:
            results = self.visit(node.value)
        finally:
            # Always clear the flag; otherwise an error raised while visiting
            # would make every later BinOp/Call look like a return expression.
            self.visiting_return = False
        self.return_names = results
        if not isinstance(results, tuple):
            results = (results,)
        for rst in results:
            try:
                example_tensor = self.example_inputs[rst]
            except Exception as err:
                raise RuntimeError(f"Example input for {rst} is not provided.") from err
            self.set_store_tensor(rst, example_tensor)
            self.mark_output(rst)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
34
+ from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
35
+ from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
36
+ from cutlass_cppgen.backend.evt.ir.load_nodes import (
37
+ LoadNode,
38
+ AccumulatorImpl,
39
+ LoadSrcImpl,
40
+ AuxLoadImpl,
41
+ RowBroadcastImpl,
42
+ ColumnBroadcastImpl,
43
+ ScalarBroadcastImpl
44
+ )
45
+ from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
46
+ from cutlass_cppgen.backend.evt.ir.store_nodes import (
47
+ StoreNode,
48
+ StoreDImpl,
49
+ AuxStoreImpl,
50
+ ColumnReductionImpl,
51
+ RowReductionImpl,
52
+ ScalarReductionImpl
53
+ )
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Python registration for compute nodes in EVT
35
+ """
36
+
37
+ from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
38
+ from cutlass_cppgen.backend.library import FloatRoundStyle
39
+
40
+
41
class ComputeImplBase(ImplBase):
    """
    Common parent of the compute-node implementations.
    """

    def __init__(self, node) -> None:
        # No extra state here; initialization is delegated to ImplBase.
        super().__init__(node)
47
+
48
+
49
class ComputeImpl(ComputeImplBase):
    """
    Implementation attached to a compute node.
    """

    def __init__(self, node) -> None:
        super().__init__(node)

        # Mirror the node's attributes on the implementation object.
        self.fn = node.fn
        self.round_style = node.round_style
        self.element_compute = node.element_compute
        self.element_output = node.element_output

    @staticmethod
    def match(node, problem_size: tuple):
        """Always matches: returns True for every node and problem size."""
        return True
64
+
65
+
66
class ComputeNode(NodeBase):
    """
    Node of the DAG IR that applies a computation ``fn``.
    """
    possible_impls = [ComputeImpl]

    def __init__(
            self, name: str, fn, element_output,
            element_compute,
            round_style=FloatRoundStyle.ToNearest) -> None:
        # NOTE(review): ``element_output`` is accepted but never stored here;
        # ``type_propagation`` fills the attribute in when it is missing.
        # Confirm whether ignoring the argument is intentional.
        super().__init__(name)
        self.op = "compute"
        self.fn = fn
        self.round_style = round_style
        self.element_compute = element_compute

    def type_propagation(self, *args, **kwargs):
        """
        Set ``element`` to ``element_compute`` and default ``element_output``
        to the same type unless it has already been set.
        """
        self.element = self.element_compute
        # In general, the compute nodes have element_output = element_compute
        # In certain cases like producer of D it is overwritten by other passes
        if hasattr(self, "element_output"):
            return
        self.element_output = self.element
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ DAG IR used by Python EVT
35
+ """
36
+
37
+ import networkx as nx
38
+
39
+ from cutlass_library import DataType
40
+
41
+ from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
42
+ from cutlass_cppgen.backend.evt.ir.node import NodeBase
43
+ from cutlass_cppgen.backend.library import ActivationOp
44
+ from cutlass_cppgen.backend.utils import device_cc
45
+
46
+
47
class DAGIR:
    """
    ``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
    It consists of a series of ``Node`` s, each representing epilogue visitor nodes.

    In the DAGIR, ``node`` is a string of its name. ``node_meta`` is the underlying class of the node
    """

    def __init__(self, cc, element_compute=DataType.f32) -> None:
        # The EVT DAGIR is managed through the networkx DiGraph class
        self._graph = nx.DiGraph()

        self.cc = cc
        self.element_compute = element_compute

        self.reduction_names = []

        # Number of identity nodes auto-inserted by ``add_edge`` so far
        self.identity_counter = 0

    #
    # IR manipulator
    #

    def add_node(self, meta: NodeBase):
        """
        Add a node to the dag ir, keyed by ``meta.name``.

        :raises SyntaxError: if a node with that name already exists
        """
        node_name = meta.name
        if self.has_node(node_name):
            raise SyntaxError(f"Variable '{node_name}' cannot be defined twice.")
        self._graph.add_node(node_name, meta=meta)

    def add_edge(self, src: str, dst: str, weight: int=0):
        """
        Add an edge src -> dst to the dag ir with the given weight.

        If the edge already exists, an identity compute node is inserted
        between ``src`` and ``dst`` instead.

        :raises SyntaxError: if ``src`` or ``dst`` is not in the graph
        """
        for endpoint in (src, dst):
            if not self.has_node(endpoint):
                raise SyntaxError(f"Variable '{endpoint}' is undefined.")

        if not self._graph.has_edge(src, dst):
            self._graph.add_edge(src, dst, weight=weight)
            return

        # The DiGraph doesn't support multiple edges between two nodes
        # We insert an identity node in such case as a workaround
        identity_name = f"autogen_identity_{self.identity_counter}"
        self.identity_counter += 1
        identity_node = ComputeNode(
            name=identity_name, fn=ActivationOp.Identity,
            element_output=self.element_compute,
            element_compute=self.element_compute)
        self.add_node(identity_node)
        self.add_edge(src, identity_name, 0)
        self.add_edge(identity_name, dst, weight)

    def remove_node(self, node: str):
        """
        Remove a node from the dag ir
        """
        self._graph.remove_node(node)

    def remove_edge(self, src: str, dst: str):
        """
        Remove the edge src -> dst
        """
        self._graph.remove_edge(src, dst)

    #
    # Helper functions for getting attrs
    #

    def has_node(self, node: str) -> bool:
        """
        Tell whether the node is in the graph
        """
        return self._graph.has_node(node)

    def in_degree(self, node: str):
        """
        Get the input degree of the node
        """
        return self._graph.in_degree(node)

    def in_edges(self, node: str):
        """
        Get the input edges of the node as a list
        """
        return list(self._graph.in_edges(node))

    def out_degree(self, node: str):
        """
        Get the output degree of the node
        """
        return self._graph.out_degree(node)

    def out_edges(self, node: str):
        """
        Get the output edges of the node as a list
        """
        return list(self._graph.out_edges(node))

    def get_node_meta(self, node: str):
        """
        Get the meta data stored with the node
        """
        return self._graph.nodes[node]["meta"]

    def get_edge_weight(self, src, dst):
        """
        Get the weight of the edge src -> dst
        """
        edge_data = self._graph.get_edge_data(src, dst)
        return edge_data["weight"]

    #
    # High-level helper functions
    #

    def all_reachable_nodes(self, node: str):
        """
        Get the nodes produced by a depth-first preorder traversal that
        starts at ``node``.

        NOTE(review): this docstring used to say the start node is excluded,
        but the result of ``nx.dfs_preorder_nodes`` is returned unfiltered -
        confirm which behavior callers rely on.
        """
        return list(nx.dfs_preorder_nodes(self._graph, source=node))

    def get_users(self, node: str):
        """
        Get all users of the node (the destinations of its output edges)
        """
        return [edge[1] for edge in self.out_edges(node)]

    def get_all_inputs(self, node: str):
        """
        Get all the input nodes sorted by edge weight
        """
        # Pair each edge with its weight so sorting orders by weight first.
        weighted_edges = sorted(
            (self.get_edge_weight(*edge), edge) for edge in self.in_edges(node)
        )
        return [edge[0] for _, edge in weighted_edges]

    def get_all_inputs_meta(self, node: str):
        """
        Get all the input node metas sorted by edge weight
        """
        return [self.get_node_meta(name) for name in self.get_all_inputs(node)]

    def replace_all_uses_with(self, node1, node2):
        """
        Redirect every use of node1 to node2, then remove node1
        """
        # ``out_edges`` returns a list, so the graph can be edited in the loop.
        for edge in self.out_edges(node1):
            user = edge[1]
            weight = self.get_edge_weight(*edge)
            self.add_edge(node2, user, weight)
            self.remove_edge(node1, user)
        self.remove_node(node1)

    #
    # Node accessor
    #
    def nodes_topological_order(self):
        """
        Get the nodes in the unique lexicographical topological order
        It generates a unique ordering of nodes by first sorting topologically
        and then additionally by sorting lexicographically.

        Although topological_sort alone also works, this generates a unique key
        for each epilogue visitor pattern and ensures the compilation cache can be reused.
        :return: list[str]
        """
        return list(nx.lexicographical_topological_sort(self._graph))

    def node_metas_topological_order(self):
        """
        Get the node metas in topological order
        :return: list[NodeBase]
        """
        return [self.get_node_meta(name) for name in self.nodes_topological_order()]

    @property
    def nodes(self):
        """
        Get all nodes
        :return: list[str]
        """
        return list(self._graph.nodes)

    @property
    def nodes_meta(self):
        """
        Get all node metas
        :return: list[NodeBase]
        """
        return [entry[1]["meta"] for entry in self._graph.nodes.data()]

    @property
    def edges(self):
        """
        Get all edges
        :return: list[(str, str)]
        """
        return list(self._graph.edges)

    #
    # Path
    #
    def has_path(self, src: str, target: str) -> bool:
        """
        Return True if a path exists from src to target
        """
        return nx.has_path(self._graph, src, target)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Layout algebras
35
+ """
36
+
37
+ from pycute import Layout, composition, make_layout, flatten, product
38
+
39
+
40
def _infer_split(old_shape, new_shape):
    """
    Compute a flat shape that both ``old_shape`` and ``new_shape`` can be
    grouped into, working recursively from the last dimension.

    :param old_shape: flat shape before the reshape
    :param new_shape: flat shape after the reshape
    :return: list of dimension sizes
    :raises ValueError: if one shape is empty and the other has a product
        different from 1
    :raises NotImplementedError: if the trailing dimensions do not divide
        one another
    """
    old_shape = _tuple_to_list(old_shape)
    new_shape = _tuple_to_list(new_shape)

    # Base cases: at least one side has run out of dimensions
    if len(old_shape) == 0 and len(new_shape) == 0:
        return []
    if len(old_shape) == 0:
        if product(tuple(new_shape)) != 1:
            raise ValueError("Invalid reshape size")
        return new_shape
    if len(new_shape) == 0:
        if product(tuple(old_shape)) != 1:
            raise ValueError("Invalid reshape size")
        return old_shape

    # Only the last dimension of each shape is processed per call
    old_rest, old_dim = old_shape[:-1], old_shape[-1]
    new_rest, new_dim = new_shape[:-1], new_shape[-1]

    # Exact match
    if old_dim == new_dim:
        return _infer_split(old_rest, new_rest) + [new_dim]
    # Needs split: the leftover factor stays on the old side
    if old_dim > new_dim and old_dim % new_dim == 0:
        leftover = old_dim // new_dim
        return _infer_split(old_rest + [leftover], new_rest) + [new_dim]
    # Needs merge: the leftover factor stays on the new side
    if old_dim < new_dim and new_dim % old_dim == 0:
        leftover = new_dim // old_dim
        return _infer_split(old_rest, new_rest + [leftover]) + [old_dim]

    raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
71
+
72
def _infer_merge(flatten_shape, shape):
    """
    Group the entries of ``flatten_shape`` so that they line up with the
    dimensions of ``shape``.

    A dimension equal to the next flat entry is kept as a plain value; a
    larger dimension that is a multiple of it becomes a list of consecutive
    flat entries.

    :return: list whose items are dimension sizes or lists of sizes
    :raises NotImplementedError: if a dimension is neither equal to nor a
        multiple of the next flat entry
    """
    flatten_shape = _tuple_to_list(flatten_shape)
    shape = _tuple_to_list(shape)
    cursor = 0
    merged = []
    for dim in shape:
        head = flatten_shape[cursor]
        # Exact match
        if dim == head:
            merged.append(dim)
            cursor += 1
            continue
        if not (dim > head and dim % head == 0):
            raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
        # Need group: consume flat entries until ``dim`` is used up
        remaining = dim
        group = []
        while remaining > 1:
            factor = flatten_shape[cursor]
            group.append(factor)
            remaining = remaining // factor
            cursor += 1
        merged.append(group)

    return merged
95
+
96
+ def _list_to_tuple(nested_list):
97
+ if isinstance(nested_list, list) or isinstance(nested_list, tuple):
98
+ return tuple(_list_to_tuple(item) for item in nested_list)
99
+ return nested_list
100
+
101
+ def _tuple_to_list(nested_tuple):
102
+ if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
103
+ return list(_tuple_to_list(item) for item in nested_tuple)
104
+ return nested_tuple
105
+
106
+ def _reverse_tuple(nested_tuple: tuple):
107
+ if isinstance(nested_tuple, tuple):
108
+ return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
109
+ return nested_tuple
110
+
111
+ def _get_first_lhs_nonzero_stride(stride_list, idx):
112
+ for i in reversed(range(idx)):
113
+ if stride_list[i] != 0:
114
+ return i
115
+ else:
116
+ return None
117
+
118
+ def _get_first_rhs_nonzero_stride(stride_list, idx):
119
+ for i in range(idx+1, len(stride_list)):
120
+ if stride_list[i] != 0:
121
+ return i
122
+ else:
123
+ return None
124
+
125
+ def reshape(layout, new_shape):
126
+ """
127
+ General reshape of input layout.
128
+ It takes two steps:
129
+ 1. split the dimensions of the old layout
130
+ 2. merge the splitted dimensions according to the new shape
131
+ """
132
+ #
133
+ # Step 1: Split the dimensions of the old layout
134
+ #
135
+ # 1.1 Flat old and new shape
136
+ old_flatten_shape = list(flatten(layout.shape))
137
+ new_flatten_shape = list(flatten(new_shape))
138
+
139
+ # 1.2 Infer the flatten splitted shape
140
+ splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
141
+
142
+ # 1.3 Unflat the splitted shape based on the old shape
143
+ splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
144
+
145
+ # 1.4 Infer the type of each split
146
+ # If the split type is in row-major (R), the dimension list is reversed because
147
+ # the cute::composition only support column-major split
148
+ split_type = [] # the type of each split (ColumnMajor or RowMajor)
149
+ permuted_splitted_shape = []
150
+ old_flatten_stride = list(flatten(layout.stride))
151
+ for idx, dim in enumerate(splited_shape):
152
+ if not isinstance(dim, list):
153
+ permuted_splitted_shape.append(dim)
154
+ split_type.append("C")
155
+ else:
156
+ lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
157
+ rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
158
+ # Special case for single tuple
159
+ # Use column-major by default
160
+ if lhs_stride is None and rhs_stride is None:
161
+ permuted_splitted_shape.append(dim)
162
+ split_type.append("C")
163
+ else:
164
+ if lhs_stride is not None and rhs_stride is not None:
165
+ # We consider shape[idx]:stride[idx]
166
+ # Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
167
+ if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
168
+ permuted_splitted_shape.append(dim)
169
+ split_type.append("C")
170
+ # Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
171
+ elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
172
+ permuted_splitted_shape.append([d for d in reversed(dim)])
173
+ split_type.append("R")
174
+ # Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
175
+ elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
176
+ if lhs_stride >= rhs_stride:
177
+ permuted_splitted_shape.append(dim)
178
+ split_type.append("C")
179
+ else:
180
+ permuted_splitted_shape.append([d for d in reversed(dim)])
181
+ split_type.append("R")
182
+ # Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
183
+ elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
184
+ if lhs_stride >= rhs_stride:
185
+ permuted_splitted_shape.append(dim)
186
+ split_type.append("C")
187
+ else:
188
+ permuted_splitted_shape.append([d for d in reversed(dim)])
189
+ split_type.append("R")
190
+ else:
191
+ raise NotImplementedError()
192
+ elif lhs_stride is None:
193
+ # No nonzero stride on the left: row major if dim's stride is larger than the right-hand stride, otherwise column major
194
+ if old_flatten_stride[idx] > rhs_stride:
195
+ permuted_splitted_shape.append([d for d in reversed(dim)])
196
+ split_type.append("R")
197
+ else:
198
+ permuted_splitted_shape.append(dim)
199
+ split_type.append("C")
200
+ else:
201
+ # No nonzero stride on the right: row major if dim's stride is smaller than the left-hand stride, otherwise column major
202
+ if old_flatten_stride[idx] < lhs_stride:
203
+ permuted_splitted_shape.append([d for d in reversed(dim)])
204
+ split_type.append("R")
205
+ else:
206
+ permuted_splitted_shape.append(dim)
207
+ split_type.append("C")
208
+
209
+ # 1.4 Generate the splitted layout
210
+ permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
211
+
212
+ # 1.5 Reverse the permutation in 1.4 before merge
213
+ splitted_shape = []
214
+ splitted_stride = []
215
+ for shape_dim, stride_dim, type in zip(
216
+ permuted_splitted_layout.shape,
217
+ permuted_splitted_layout.stride,
218
+ split_type):
219
+ if type == "C":
220
+ splitted_shape.append(shape_dim)
221
+ splitted_stride.append(stride_dim)
222
+ else:
223
+ splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
224
+ splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
225
+ splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
226
+
227
+
228
+ #
229
+ # Step 2: Merge the splitted dimensions according to the new shape
230
+ #
231
+ # 2.1 Merge layout
232
+ merged_layout = composition(splitted_layout, Layout(new_shape))
233
+
234
+ # 2.2 Cleaning up
235
+ output_layout = composition(merged_layout, Layout(new_shape))
236
+ return output_layout
237
+
238
+
239
def permutation(layout, permutation):
    """
    Reorder the top-level modes of a layout

    :param layout: layout whose ``shape`` and ``stride`` are indexed per mode
    :param permutation: mode indices; output mode ``i`` is input mode ``permutation[i]``
    :return: a new ``Layout`` built from the reordered shape and stride
    """
    old_shape = layout.shape
    old_stride = layout.stride
    # Shape and stride are gathered with the same index sequence so they stay paired.
    permuted_shape = tuple(old_shape[src] for src in permutation)
    permuted_stride = tuple(old_stride[src] for src in permutation)
    return Layout(permuted_shape, permuted_stride)
246
+
247
+
248
def _broadcast(layout, new_shape):
    """
    Recursive worker behind :func:`broadcast`

    :param layout: layout to broadcast
    :param new_shape: target shape; an ``int`` for a single mode, otherwise an
        indexable sequence with one entry per target mode
    :return: a layout whose modes were each broadcast to the matching entry of
        ``new_shape``; a size-1 mode that gets expanded receives stride 0
    :raises NotImplementedError: if a mode is neither equal to its target extent
        nor of extent 1, for both alignment attempts
    """
    # Base case: a single mode against a single integer extent.
    if len(layout) == 1 and isinstance(new_shape, int):
        old_dim = layout.shape
        old_stride = layout.stride
        new_dim = new_shape
        if old_dim == new_dim:
            return Layout(old_dim, old_stride)
        elif old_dim == 1:
            return Layout(new_dim, 0)
        else:
            raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")

    # Align the dimensions
    old_shape = layout.shape
    if isinstance(old_shape, int):
        old_shape = (old_shape,)
        sub_layouts = [layout,]
    else:
        sub_layouts = [sub_layout for sub_layout in layout]
    # One (1:0) placeholder mode per rank that the target has and the input lacks.
    rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))

    def _broadcast_modes(aligned_sub_layouts):
        # Every attempt builds its own list, so a failure halfway through one
        # alignment cannot leak partially broadcast modes into the next attempt.
        aligned_layout = make_layout(*aligned_sub_layouts)
        return [
            _broadcast(sub_layout, new_shape[idx])
            for idx, sub_layout in enumerate(aligned_layout)
        ]

    # Get the broadcasted layout: first try with the placeholder modes appended
    # after the existing modes, then fall back to placing them in front.
    try:
        broadcast_layouts = _broadcast_modes([*sub_layouts, *rhs_broadcast_layouts])
    except NotImplementedError:
        # Fix: the fallback previously appended to the list left over from the
        # failed first attempt, which produced a layout with too many modes.
        broadcast_layouts = _broadcast_modes([*rhs_broadcast_layouts, *sub_layouts])
    return make_layout(*broadcast_layouts)
280
+
281
+
282
def broadcast(layout, new_shape):
    """
    Broadcast a layout to the requested shape

    Public entry point; the work is delegated to ``_broadcast``.
    The shape of the result equals ``new_shape`` and the broadcasted
    dimensions carry stride 0.
    """
    broadcasted_layout = _broadcast(layout, new_shape)
    return broadcasted_layout
289
+
290
+
291
def debroadcast(layout, dims):
    """
    Drop the given 0-stride modes from a layout

    :param layout: layout to squeeze
    :param dims: indices of the modes to remove; each must have stride 0
    :return: a new ``Layout`` without the listed modes
    :raises ValueError: if any listed mode has a nonzero stride
    """
    # Validate first: only modes with stride 0 may be removed.
    for dim in dims:
        if layout.stride[dim] == 0:
            continue
        raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")

    kept_shape = tuple(extent for pos, extent in enumerate(layout.shape) if pos not in dims)
    kept_stride = tuple(step for pos, step in enumerate(layout.stride) if pos not in dims)
    return Layout(kept_shape, kept_stride)
301
+
302
+
303
def canonicalization_(shapes, strides):
    """
    Recursive helper of ``canonicalization``

    Walks ``shapes`` and ``strides`` in lockstep. Every leaf whose shape is 1
    becomes the pair ``(1, 0)``; all other leaves are returned unchanged.
    Nested tuples are rebuilt as tuples.

    :return: the pair ``(canonical_shapes, canonical_strides)``
    """
    # Leaf: a single extent with its stride.
    if not isinstance(shapes, tuple):
        if shapes == 1:
            return 1, 0
        return shapes, strides

    # Tuple: canonicalize each (shape, stride) pair, then regroup the results.
    pairs = [canonicalization_(sub_shape, sub_stride)
             for sub_shape, sub_stride in zip(shapes, strides)]
    canonical_shapes = tuple(pair[0] for pair in pairs)
    canonical_strides = tuple(pair[1] for pair in pairs)
    return canonical_shapes, canonical_strides
317
+
318
def canonicalization(layout):
    """
    Return the canonical form of the input layout

    Currently a single rule is applied:
    1. every mode of shape "1" gets stride 0
    """
    return Layout(*canonicalization_(layout.shape, layout.stride))
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Layout manipulation nodes and implementations
35
+
36
+ The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
37
+ """
38
+
39
+ from copy import deepcopy
40
+
41
+ from cutlass_library import LayoutType
42
+ from pycute import product, flatten
43
+
44
+ import cutlass_cppgen
45
+ from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
46
+ from cutlass_cppgen.backend.evt.ir.node import NodeBase
47
+ from cutlass_cppgen.backend.evt.ir.tensor import Tensor
48
+
49
+
50
class PermutationImpl:
    """
    Detailed implementation and helper functions for permutation

    ``indices`` holds the permutation (output dim ``i`` reads input dim
    ``indices[i]``) and ``inverse_indices`` holds its inverse.
    """
    def __init__(self, node) -> None:
        assert "indices" in node.kwargs.keys()
        # Copy so later updates never touch the caller's kwargs.
        self.indices = list(node.kwargs["indices"])
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_impl(self):
        """
        Return a new impl that applies the inverse permutation
        """
        inverse_impl = deepcopy(self)
        # Fix: hand out copies. Sharing the list objects let a later update()
        # on either impl silently modify the other one.
        inverse_impl.indices = list(self.inverse_indices)
        inverse_impl.inverse_indices = list(self.indices)
        return inverse_impl

    def update(self, shape):
        """
        Extend the permutation to the rank of ``shape``

        Existing indices are shifted by the number of added leading dims, and
        one index per added dim is placed in front.
        """
        num_dim = len(shape)
        num_new_dim = num_dim - len(self.indices)
        # Add offset (builds a new list instead of mutating in place, so any
        # other holder of the old list is unaffected)
        shifted = [idx + num_new_dim for idx in self.indices]
        # Add broadcast dims. They are prepended one at a time in the original
        # algorithm, which yields [k-1, ..., 1, 0]; that order is kept here.
        leading = list(range(num_new_dim - 1, -1, -1))

        self.indices = leading + shifted
        self.inverse_indices = self.get_inverse_indices(self.indices)

    def get_inverse_indices(self, indices):
        """
        Get the indices for inverse permutation
        """
        inverse_indices = [0] * len(indices)
        for position, source in enumerate(indices):
            inverse_indices[source] = position
        return inverse_indices

    def shape_propagation(self, input_node_meta):
        """
        Return the output shape obtained by permuting the input tensor's shape
        """
        input_shape = input_node_meta.tensor.shape
        return tuple(input_shape[idx] for idx in self.indices)

    def broadcast(self, shape, node_meta: 'NodeBase'):
        """
        Broadcast the inputs based on current shape
        """
        self.update(shape)
        # The input must be broadcast to the shape seen *before* the permutation.
        inverse_shape = tuple(shape[idx] for idx in self.inverse_indices)
        node_meta.tensor.broadcast(inverse_shape)

    def apply_to_user(self, usr_meta: 'NodeBase'):
        """
        Propagate the permutation to the users of the current nodes
        """
        usr_meta.tensor.permute(self.inverse_indices)
        if hasattr(usr_meta, "store_tensor"):
            if usr_meta.store_tensor is not None:
                usr_meta.store_tensor.permute(self.inverse_indices)

    def apply_to_input(self, input_meta: 'NodeBase'):
        """
        Propagate the permutation to inputs of the current nodes
        """
        input_meta.tensor.permute(self.indices)
        if hasattr(input_meta, "store_tensor"):
            if input_meta.store_tensor is not None:
                input_meta.store_tensor.permute(self.indices)
119
+
120
+
121
class ReshapeImpl:
    """
    Detailed implementation and helper functions for reshape
    """
    def __init__(self, node) -> None:
        assert "new_shape" in node.kwargs.keys()
        self.node = node
        self.output_shape = _list_to_tuple(node.kwargs["new_shape"])

    def get_inverse_impl(self):
        """
        Return a copy of this impl with input and output shapes swapped
        """
        swapped = deepcopy(self)
        # Right-hand side is evaluated first, so both values come from `self`.
        swapped.output_shape, swapped.input_shape = self.input_shape, self.output_shape
        return swapped

    def shape_propagation(self, input_node_meta):
        """
        Record the input shape and return the requested output shape
        """
        self.input_shape = input_node_meta.tensor.shape
        return _list_to_tuple(self.output_shape)

    def broadcast(self, shape, node_meta: 'NodeBase'):
        """
        Broadcast the inputs based on current shape.
        """
        # Step 1: infer the common flat split, then regroup it for each side
        flat_split = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
        split_in = self.infer_merge(flat_split, self.input_shape)
        split_out = self.infer_merge(flat_split, self.output_shape)

        # broadcast shape -> split_output_shape -> flatten_split_shape
        # Left-pad with 1s when `shape` has more dims than the split output.
        num_pad = len(shape) - len(split_out)
        if num_pad > 0:
            padding = [1,] * num_pad
            split_out = padding + split_out
            split_in = padding + split_in

        # One broadcast factor per flattened split dimension.
        factors = []
        for target, current in zip(shape, split_out):
            target_group = target if isinstance(target, list) else [target,]
            current_group = current if isinstance(current, list) else [current,]
            target_size = product(tuple(target_group))
            current_size = product(tuple(current_group))
            if target_size == current_size:
                factors += [1] * len(current_group)
            elif current_size == 1:
                assert len(target_group) == 1
                factors.append(target_group[0])
            else:
                raise NotImplementedError(f"Invalid Broadcast: {current_group} -> {target_group}")

        # flatten_split_shape -> split_input_shape
        # Consume the factors in flattened order while keeping the grouping.
        cursor = 0
        scaled_split_in = []
        for entry in split_in:
            if isinstance(entry, list):
                scaled_group = []
                for extent in entry:
                    scaled_group.append(extent * factors[cursor])
                    cursor += 1
                scaled_split_in.append(scaled_group)
            else:
                scaled_split_in.append(entry * factors[cursor])
                cursor += 1
        scaled_split_in = _list_to_tuple(scaled_split_in)

        node_meta.tensor.reshape(_list_to_tuple(split_in))
        node_meta.tensor.broadcast(scaled_split_in)
        # Last reshape op to clean up: collapse every group back to one dim
        collapsed_shape = tuple([product(group) for group in scaled_split_in])
        node_meta.tensor.reshape(collapsed_shape)

        # Update the input shape and output shape
        self.input_shape = _list_to_tuple(node_meta.tensor.shape)
        self.output_shape = _list_to_tuple(shape)

    def apply_to_user(self, user_meta: 'NodeBase'):
        """
        Propagate the reshape to user nodes
        """
        user_meta.tensor.reshape(tuple(self.input_shape))
        if hasattr(user_meta, "store_tensor") and user_meta.store_tensor is not None:
            user_meta.store_tensor.reshape(tuple(self.input_shape))

    def apply_to_input(self, input_meta: 'NodeBase'):
        """
        Propagate the reshape to input nodes
        """
        input_meta.tensor.reshape(tuple(self.output_shape))
        if hasattr(input_meta, "store_tensor") and input_meta.store_tensor is not None:
            input_meta.store_tensor.reshape(tuple(self.output_shape))

    #
    # Helper functions
    #

    def infer_split(self, input_shape, output_shape):
        """
        Infer the flatten splitted shape that can be merged to both input_shape and output_shape
        """
        input_shape = _tuple_to_list(input_shape)
        output_shape = _tuple_to_list(output_shape)

        # Termination cases: one or both sides are exhausted.
        if len(input_shape) == 0 and len(output_shape) == 0:
            return []
        if len(input_shape) == 0:
            if product(tuple(output_shape)) != 1:
                raise ValueError("Invalid reshape size")
            return output_shape
        if len(output_shape) == 0:
            if product(tuple(input_shape)) != 1:
                raise ValueError("Invalid reshape size")
            return input_shape

        # This is done recursively by only processing the last dimension each time
        in_last = input_shape[-1]
        out_last = output_shape[-1]
        in_rest = input_shape[:-1]
        out_rest = output_shape[:-1]

        # Exact match
        if in_last == out_last:
            return self.infer_split(in_rest, out_rest) + [out_last,]
        # Needs split: peel `out_last` off the input dim, push the quotient back
        if in_last > out_last and in_last % out_last == 0:
            quotient = in_last // out_last
            return self.infer_split(in_rest + [quotient,], out_rest) + [out_last,]
        # Needs merge: peel `in_last` off the output dim, push the quotient back
        if in_last < out_last and out_last % in_last == 0:
            quotient = out_last // in_last
            return self.infer_split(in_rest, out_rest + [quotient,]) + [in_last,]

        raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")

    def infer_merge(self, flatten_shape, shape):
        """
        Group the entries of ``flatten_shape`` so that they match ``shape``

        Both are scanned from the last dimension backwards. A dimension equal
        to the current flat entry is kept as is; a larger, divisible dimension
        becomes a list of the flat entries it consumes.
        """
        flatten_shape = _tuple_to_list(flatten_shape)
        shape = _tuple_to_list(shape)
        cursor = len(flatten_shape) - 1
        groups_reversed = []
        for dim in reversed(shape):
            flat_dim = flatten_shape[cursor]
            # Exact match
            if dim == flat_dim:
                groups_reversed.append(dim)
                cursor -= 1
            # need group
            elif dim > flat_dim and dim % flat_dim == 0:
                remaining = dim
                collected = []
                while remaining > 1:
                    collected.append(flatten_shape[cursor])
                    remaining = remaining // flatten_shape[cursor]
                    cursor -= 1
                # Entries were collected back to front; restore forward order.
                groups_reversed.append(collected[::-1])
            else:
                raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")

        return groups_reversed[::-1]
272
+
273
+
274
class LayoutNode(NodeBase):
    """
    Layout manipulation nodes
    """
    # Maps the ``__name__`` of the layout function to its implementation class.
    fn_to_impl = {
        "permute": PermutationImpl,
        "reshape": ReshapeImpl
    }
    def __init__(self, name: str, fn, kwargs: dict) -> None:
        super().__init__(name)
        self.op = "layout"
        self.fn = fn
        self.kwargs = kwargs
        impl_cls = self.fn_to_impl[self.fn.__name__]
        # The impl reads ``self.kwargs``, so it is built last.
        self.underlying_impl = impl_cls(self)

    def get_inverse_node(self):
        """
        Return a deep copy of this node whose impl is the inverse operation
        """
        inverse_node = deepcopy(self)
        inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
        return inverse_node

    def shape_propagation(self, input_node_metas):
        """
        Create the output tensor from the single input node, unless already set
        """
        if self._tensor is not None:
            return
        assert len(input_node_metas) == 1, "Layout node can only have one input node"

        propagated_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
        self._tensor = Tensor(
            element=self.element_output,
            shape=propagated_shape,
            layout_tag=LayoutType.RowMajor,
        )

        return super().shape_propagation(input_node_metas)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The layout node has element_output = element_output of its input
        """
        assert len(input_node_metas) == 1, "Layout node can only have one input node"
        self.element_output = input_node_metas[0].element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        target_shape = self.tensor.shape

        for input_meta in input_node_metas:
            self.underlying_impl.broadcast(target_shape, input_meta)

    def apply_to_user(self, usr_meta: NodeBase):
        """
        Propagate the layout operation to user nodes
        """
        self.underlying_impl.apply_to_user(usr_meta)

    def apply_to_input(self, input_meta: NodeBase):
        """
        Propagate the layout operation to input nodes
        """
        self.underlying_impl.apply_to_input(input_meta)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Load nodes and implementations
35
+ """
36
+
37
+ import ctypes
38
+
39
+ from cutlass_cppgen.backend.c_types import tuple_factory
40
+ from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
41
+ from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
42
+
43
+
44
class LoadImplBase(ImplBase):
    """
    Base class for load node implementations
    """
    # Node names that the generic load impls refuse to match.
    reserved_names = ["accum", "C"]

    def __init__(self, node) -> None:
        super().__init__(node)
        # Cache the element types and the tensor stride of the wrapped node.
        self.element = node.element
        self.element_output = node.element_output
        self.stride = node.tensor.stride
54
+
55
+
56
class AccumulatorImpl(LoadImplBase):
    """
    Accumulator node implementation
    """

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match the node named "accum" whose tensor shape equals the problem size
        """
        if node.name != "accum":
            return False
        return node.tensor.shape == problem_size
64
+
65
+
66
class LoadSrcImpl(LoadImplBase):
    """
    Load C implementation
    """
    @property
    def name_camel(self) -> str:
        return "TensorC"

    @property
    def argument_type_c(self):
        """
        Build the ctypes structure holding the C pointer and its stride
        """
        stride_values = self.get_stride_mnl()
        stride_ctype = tuple_factory(stride_values, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_C", ctypes.c_void_p),
                ("stride_C", stride_ctype)
            ]

            def __init__(self, ptr) -> None:
                self.ptr_C = ptr
                self.stride_C = stride_ctype(stride_values)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match the node named "C" whose tensor shape equals the problem size
        """
        if node.name != "C":
            return False
        return node.tensor.shape == problem_size
92
+
93
+
94
class AuxLoadImpl(LoadImplBase):
    """
    Load arbitrary tensor
    """
    @property
    def argument_type(self):
        """
        Build the ctypes structure holding the aux pointer, null default and stride
        """
        stride_values = self.get_stride_mnl()
        arg_name = self.name
        stride_ctype = tuple_factory(stride_values, self.stride_dtype)
        element = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("null_default", dtype2ctype[element]),
                ("dAux", stride_ctype)
            ]

            def __init__(self, kwargs) -> None:
                # The pointer is looked up under this node's name.
                self.ptr_aux = kwargs[arg_name]
                self.null_default = to_ctype_value(0, element)
                self.dAux = stride_ctype(stride_values)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match non-reserved nodes where one of the last two strides is 1 and the other is nonzero
        """
        if node.name in LoadImplBase.reserved_names:
            return False
        strideMN = node.tensor.stride[-2:]
        unit_first = strideMN[0] == 1 and strideMN[1] != 0
        if unit_first:
            return True
        unit_second = strideMN[0] != 0 and strideMN[1] == 1
        return True if unit_second else False
128
+
129
+
130
class RowBroadcastImpl(LoadImplBase):
    """
    Broadcast a row vector
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        """
        Build the ctypes structure holding the row pointer, null default and stride
        """
        stride_values = self.get_stride_mnl()
        arg_name = self.name
        stride_ctype = tuple_factory(stride_values, self.stride_dtype)
        element = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_row", ctypes.c_void_p),
                ("null_default", dtype2ctype[element]),
                ("dRow", stride_ctype)
            ]

            def __init__(self, kwargs) -> None:
                # The pointer is looked up under this node's name.
                self.ptr_row = kwargs[arg_name]
                self.null_default = to_ctype_value(0, element)
                self.dRow = stride_ctype(stride_values)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match non-reserved nodes whose last two strides are (0, 1)
        """
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        return bool(strideMN == (0, 1))
168
+
169
+
170
class ColumnBroadcastImpl(LoadImplBase):
    """
    Broadcast a column vector
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        """
        Build the ctypes structure holding the column pointer, null default and stride
        """
        stride_values = self.get_stride_mnl()
        arg_name = self.name
        stride_ctype = tuple_factory(stride_values, self.stride_dtype)
        element = self.element

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_col", ctypes.c_void_p),
                ("null_default", dtype2ctype[element]),
                ("dCol", stride_ctype)
            ]

            def __init__(self, kwargs) -> None:
                # The pointer is looked up under this node's name and cast to int.
                self.ptr_col = int(kwargs[arg_name])
                self.null_default = to_ctype_value(0, element)
                self.dCol = stride_ctype(stride_values)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match non-reserved nodes whose last two strides are (1, 0)
        """
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        return bool(strideMN == (1, 0))
208
+
209
+
210
class ScalarBroadcastImpl(LoadImplBase):
    """
    Broadcast a scalar
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride_dtype = "int"

    @property
    def argument_type(self):
        """
        Build the ctypes structure holding the scalar value, scalar pointer and stride

        For a constant tensor the value is baked into the structure; otherwise
        it is read from the kwargs entry named after this node.
        """
        stride_values = self.get_stride_mnl()
        arg_name = self.name
        stride_ctype = tuple_factory(stride_values, self.stride_dtype)
        element = self.element

        # Both variants share the same field layout.
        scalar_fields = [
            ("scalars", dtype2ctype[element]),
            ("scalar_ptrs", ctypes.c_void_p),
            ("dScalar", stride_ctype)
        ]

        if self.tensor.is_constant:
            constant = self.tensor.value

            class _Argument(ctypes.Structure):
                _fields_ = scalar_fields

                def __init__(self, kwargs) -> None:
                    self.scalars = to_ctype_value(constant, element)
                    self.scalar_ptrs = 0
                    self.dScalar = stride_ctype(stride_values)

        else:
            class _Argument(ctypes.Structure):
                _fields_ = scalar_fields

                def __init__(self, kwargs) -> None:
                    supplied = kwargs[arg_name]
                    if not isinstance(supplied, float):
                        # Anything that is not a float is used as a pointer.
                        self.scalar_ptrs = int(supplied)
                    else:
                        self.scalars = to_ctype_value(supplied, element)
                        self.scalar_ptrs = 0

                    self.dScalar = stride_ctype(stride_values)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match non-reserved nodes whose last two strides are (0, 0)
        """
        if node.name in LoadImplBase.reserved_names:
            return False

        strideMN = node.tensor.stride[-2:]
        return bool(strideMN == (0, 0))
267
+
268
+
269
class LoadNode(NodeBase):
    """
    Load Node
    """
    # Counter used to generate names for load nodes created without one.
    cnt = 0
    # Candidate implementation classes for a load node.
    possible_impls = [
        AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
        RowBroadcastImpl, ColumnBroadcastImpl,
        ScalarBroadcastImpl
    ]

    def __init__(self, name: str) -> None:
        if name is None:
            # Auto-generate "load<N>" and advance the class-wide counter.
            generated = f"load{LoadNode.cnt}"
            LoadNode.cnt += 1
            name = generated
        super().__init__(name)
        self.op = "load"

    def type_propagation(self, *args, **kwargs):
        """
        Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")

        tensor_element = self.tensor.element
        self.element = tensor_element
        self.element_output = tensor_element
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Base & visitor classes of DAGIR Nodes
35
+ """
36
+
37
+ import ctypes
38
+ from re import sub
39
+
40
+ from cutlass_library import LayoutType
41
+
42
+ from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
43
+ from cutlass_cppgen.backend.evt.ir.tensor import Tensor
44
+
45
+
46
class TupleEmitter:
    """
    Emit the cute tuple to C++ code

    `stride_dtype` is the C++ type name emitted for every integer entry that
    is neither 0 nor 1.
    """
    def __init__(self, stride_dtype):
        self.stride_dtype = stride_dtype

    def emit(self, py_tuple):
        """
        Return the C++ type string for `py_tuple`.

        * an int equal to 0 or 1 becomes `cute::Int<0>` / `cute::Int<1>`
        * any other int becomes `self.stride_dtype`
        * a tuple becomes `cute::Stride<...>` with each item emitted
          recursively; an empty tuple yields `cute::Stride<>`

        Raises ValueError for any other input type.
        """
        if isinstance(py_tuple, int):
            if py_tuple in [0, 1]:
                return f"cute::Int<{py_tuple}>"
            return f"{self.stride_dtype}"

        if isinstance(py_tuple, tuple):
            # join() keeps the output well-formed for an empty tuple, where
            # trimming a trailing ", " would eat part of "cute::Stride<".
            items = ", ".join(self.emit(item) for item in py_tuple)
            return "cute::Stride<" + items + ">"

        raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
66
+
67
+
68
class ImplBase:
    """
    Base class for Node Implementation

    Stores the wrapped node together with its name and tensor, and owns a
    `TupleEmitter` used to render stride types. The stride helpers read
    `self.stride`, which this base class does not assign.
    """
    def __init__(self, node) -> None:
        self.node = node
        self.name = node.name
        self.tensor = node.tensor
        self._type_decl = None
        self.tuple_emitter = TupleEmitter("int64_t")

    @property
    def stride_dtype(self):
        """Stride element type name, stored on the tuple emitter."""
        return self.tuple_emitter.stride_dtype

    @stride_dtype.setter
    def stride_dtype(self, stride_dtype):
        self.tuple_emitter.stride_dtype = stride_dtype

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match function used in get_underlying_impl
        """
        raise NotImplementedError("The `match` function is not defined.")

    @property
    def argument_type(self):
        """
        Default class for Argument Type: a ctypes structure without fields
        whose constructor accepts and ignores any arguments.
        """
        class _Argument(ctypes.Structure):
            _fields_ = []

            def __init__(self, *args, **kwargs) -> None:
                pass

        return _Argument

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        spaced = sub(r"(_|-)+", " ", self.name)
        return spaced.title().replace(" ", "")

    def _stride_in_mnl_order(self):
        """
        Reorder `self.stride`: the last two entries first, followed by the
        remaining leading entries in reversed order.
        """
        trailing_pair = [self.stride[-2], self.stride[-1]]
        leading_reversed = list(_reverse_tuple(tuple(self.stride[:-2])))
        return _list_to_tuple(trailing_pair + leading_reversed)

    @property
    def stride_mnl(self):
        """
        Typename StrideMNL
        """
        return self.tuple_emitter.emit(self._stride_in_mnl_order())

    def get_non_constant_stride(self, py_tuple):
        """
        Keep only the stride entries that are neither 0 nor 1.

        For an int, return it unchanged unless it is 0 or 1, in which case
        None is returned. For an iterable, recurse into each item and collect
        the truthy results into a tuple (so None and empty tuples are dropped).
        """
        if isinstance(py_tuple, int):
            return py_tuple if py_tuple not in [0, 1] else None

        filtered = (self.get_non_constant_stride(item) for item in py_tuple)
        return tuple(entry for entry in filtered if entry)

    def get_stride_mnl(self):
        """
        Get the stride reordered as (M, N, L...). This is used in argument
        construction. No entries are removed.
        """
        return self._stride_in_mnl_order()

    def get_smem_size(self, *args, **kwargs):
        """
        Get the shared memory size and alignment of current node
        """
        return (0, 1)
147
+
148
+
149
class NoOpImpl(ImplBase):
    """
    The NoOpImpl does nothing but forward its input to users
    """

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match store nodes that are not outputs.

        Returns None (falsy) for any node whose op is not "store";
        `problem_size` is not consulted.
        """
        if node.op != "store":
            return None
        # Store that is not output is a No OP
        return not node.is_output
161
+
162
+
163
class NodeBase:
    """
    Base class of DAG Node

    Keeps the node name, its output tensor, the implementation selected by
    `get_underlying_impl`, and a flag that disables emission. The methods
    below also read `self.possible_impls` and `self.element_output`, which
    this base class does not define.
    """
    def __init__(self, name: str) -> None:
        self.name = name

        # Backing storage of the `tensor` property
        self._tensor = None

        # Selected implementation; None until get_underlying_impl succeeds
        self.underlying_impl = None

        # Whether the node is disabled for emit
        self.disabled = False

    @property
    def name_camel(self) -> str:
        """
        Return the CamelCase name.
        """
        return self.underlying_impl.name_camel

    @property
    def tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._tensor

    @tensor.setter
    def tensor(self, kwargs):
        """
        Build the output tensor from a dict of `Tensor` keyword arguments.
        """
        self._tensor = Tensor(**kwargs)

    #
    # Helper functions for type/shape propagation
    #

    def shape_propagation(self, input_node_metas):
        """
        Infer the output shape from the input nodes using NumPy-style
        broadcasting and store it as a RowMajor tensor.

        Shapes are aligned on their trailing dimension; shorter shapes are
        padded on the left with 1. Two dimensions are compatible when they
        are equal or when one of them is 1. Nothing happens when the node
        already has a tensor.

        Raises RuntimeError when two dimensions are incompatible.
        """
        if self._tensor is not None:
            return

        merged = None
        for src in input_node_metas:
            incoming = src.tensor.shape
            if merged is None:
                merged = incoming
                continue

            # Left-pad the shorter shape with ones so both have equal rank
            rank = max(len(merged), len(incoming))
            lhs = [1] * (rank - len(merged)) + list(merged)
            rhs = [1] * (rank - len(incoming)) + list(incoming)

            # Walk from the trailing dimension towards the leading one
            dims_reversed = []
            for lhs_dim, rhs_dim in zip(reversed(lhs), reversed(rhs)):
                if lhs_dim == 1:
                    dims_reversed.append(rhs_dim)
                elif rhs_dim == 1 or lhs_dim == rhs_dim:
                    dims_reversed.append(lhs_dim)
                else:
                    described = ", ".join(
                        f"{src_.name}{src_.tensor.shape}" for src_ in input_node_metas
                    )
                    raise RuntimeError("Dimension mismatch between " + described + ".")
            merged = tuple(reversed(dims_reversed))

        self._tensor = Tensor(element=self.element_output, shape=merged, layout_tag=LayoutType.RowMajor)

    def type_propagation(self, *args, **kwargs):
        """
        Each node is associated with two data types: `element` and `element_output`.
        The `element_output` is the type of return array of the node. The `element`
        has specific meaning for different node types.
        * Load Node: data type of tensor in gmem
        * Compute Node: element compute
        * Store Node: data type of tensor in gmem
        This function must be overloaded in the derived classes
        """
        raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Propagate the broadcast in the reversed topological order.
        For example:
            C[l, m, n] = A[m, 1] + B[l, m, n]
        After the broadcast propagation, it will become
            C[l, m, n] = A[l, m, n] + B[l, m, n]
        and each tensor will have a proper stride accessing the underlying tensor

        Raises RuntimeError when this node has no tensor yet.
        """
        if self.tensor is None:
            raise RuntimeError(f"The tensor of node {self.name} is unknown.")
        target_shape = self.tensor.shape
        for child in input_node_metas:
            child.tensor.broadcast(target_shape)

    def get_underlying_impl(self, problem_size: tuple):
        """
        Get the underlying implementation of the current node.

        The first entry of `self.possible_impls` whose `match` accepts this
        node is instantiated and stored in `underlying_impl`.

        Raises RuntimeError when the node has no tensor, and
        NotImplementedError when `underlying_impl` is still None afterwards.
        """
        if self.tensor is None:
            raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")

        chosen = next(
            (impl for impl in self.possible_impls if impl.match(self, problem_size)),
            None,
        )
        if chosen is not None:
            self.underlying_impl = chosen(self)

        if self.underlying_impl is None:
            raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
286
+
287
+ #
288
+ # Visitor Nodes & Impls
289
+ #
290
+
291
class TopoVisitorImpl(ImplBase):
    """
    Impl for topological visitor

    Initialized from the visitor node's `output_node`, but named after the
    visitor node itself.
    """
    def __init__(self, node) -> None:
        output_node = node.output_node
        super().__init__(output_node)
        # The base class took the output node's name; use the visitor's instead
        self.name = node.name
        self.element_output = output_node.element_output
299
+
300
class TopoVisitorNode(NodeBase):
    """
    Node with op "dag" that holds a subgraph and its output node.
    """
    def __init__(self, name: str, subgraph, output_node) -> None:
        super().__init__(name)
        self.op = "dag"
        self.subgraph = subgraph
        self.output_node = output_node
        # Built last: TopoVisitorImpl reads `output_node` and `name` from self
        self.underlying_impl = TopoVisitorImpl(self)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Store node and implementations
35
+ """
36
+
37
+ import ctypes
38
+
39
+ from cutlass_library import DataType
40
+
41
+ from cutlass_cppgen.backend.c_types import tuple_factory
42
+ from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
43
+ from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
44
+ from cutlass_cppgen.backend.evt.ir.tensor import Tensor
45
+ from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
46
+
47
+
48
class StoreImplBase(ImplBase):
    """
    Base class for store node implementation

    Copies `element` and `element_output` from the node and takes its stride
    from the node's `store_tensor`.
    """
    # Node names that the auxiliary-store and reduction matchers reject
    reserved_names = ["D"]

    def __init__(self, node) -> None:
        super().__init__(node)
        self.stride = node.store_tensor.stride
        self.element = node.element
        self.element_output = node.element_output
58
+
59
+
60
class StoreDImpl(StoreImplBase):
    """
    Store D implementation
    """

    @property
    def argument_type_d(self):
        """
        Build the ctypes structure for the D store: a `ptr_D` pointer
        followed by a `stride_D` tuple holding the MNL-ordered stride.
        """
        mnl_stride = self.get_stride_mnl()
        stride_ctype = tuple_factory(mnl_stride, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_D", ctypes.c_void_p),
                ("stride_D", stride_ctype),
            ]

            def __init__(self, ptr: int) -> None:
                self.ptr_D = ptr
                self.stride_D = stride_ctype(mnl_stride)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match the node named "D" whose store tensor shape equals
        `problem_size`.
        """
        if node.name != "D":
            return False
        if node.store_tensor.shape == problem_size:
            return True
        return False
85
+
86
+
87
class AuxStoreImpl(StoreImplBase):
    """
    Store implementation whose argument carries a `ptr_aux` pointer and a
    `dAux` stride.
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.round_style = FloatRoundStyle.ToNearest

    @property
    def argument_type(self):
        """
        Build the ctypes structure for this store. Its constructor takes a
        dict and reads the pointer stored under this node's name.
        """
        key = self.name
        mnl_stride = self.get_stride_mnl()
        stride_ctype = tuple_factory(mnl_stride, self.stride_dtype)

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr_aux", ctypes.c_void_p),
                ("dAux", stride_ctype),
            ]

            def __init__(self, kwargs) -> None:
                self.ptr_aux = kwargs[key]
                self.dAux = stride_ctype(mnl_stride)

        return _Argument

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match output nodes with a non-reserved name whose last two strides
        contain a 1 while the other one is non-zero.
        """
        if not node.is_output:
            return False
        if node.name in StoreImplBase.reserved_names:
            return False

        stride_mn = node.store_tensor.stride[-2:]
        if stride_mn[0] == 1 and stride_mn[1] != 0:
            return True
        if stride_mn[0] != 0 and stride_mn[1] == 1:
            return True
        return False
122
+
123
+
124
class ReductionImplBase(StoreImplBase):
    """
    Shared base of the reduction store implementations.

    Takes `element` from the node's store tensor, copies the compute type,
    both reduce functions and the round style from the node, and switches the
    stride type to "int".
    """
    def __init__(self, node) -> None:
        super().__init__(node)
        self.element = node.store_tensor.element
        self.element_compute = node.element_compute
        self.reg_reduce_fn = self.node.reg_reduce_fn
        self.gmem_reduce_fn = self.node.gmem_reduce_fn
        self.round_style = node.round_style
        self.stride_dtype = "int"

    def get_reduce_identity(self):
        """
        Return the reduction identity of the current reduce_fn

        * Maximum    -> the table minimum for `element_compute`
        * Minimum    -> the table maximum for `element_compute`
        * Multiplies -> 1
        * otherwise  -> 0

        Raises Exception when `element_compute` has no table entry for a
        Maximum or Minimum reduction.
        """
        # NOTE(review): the f32 / f16 entries are integer-style bounds, not
        # the largest finite values of those formats -- confirm intended.
        maxes = {
            DataType.f32: (2 ** 31) - 1,
            DataType.f16: (2 ** 15),
            DataType.s32: (2 ** 31) - 1,
            DataType.s8: (2 ** 7) - 1
        }
        # Each minimum is the negated maximum of the same type
        mins = {dtype: -bound for dtype, bound in maxes.items()}

        reduce_fn = self.reg_reduce_fn
        compute_type = self.element_compute

        if reduce_fn == FunctionalOp.Maximum:
            if compute_type not in mins:
                raise Exception(f"No min entry for data type {self.element_compute}")
            return to_ctype_value(mins[compute_type], compute_type)
        if reduce_fn == FunctionalOp.Multiplies:
            return to_ctype_value(1., compute_type)
        if reduce_fn == FunctionalOp.Minimum:
            if compute_type not in maxes:
                raise Exception(f"No max entry for data type {self.element_compute}")
            return to_ctype_value(maxes[compute_type], compute_type)
        return to_ctype_value(0., compute_type)

    @property
    def argument_type(self):
        """
        Build the ctypes structure for the reduction: output pointer, reduce
        identity and MNL-ordered stride. Its constructor takes a dict and
        reads the pointer stored under this node's name.
        """
        # The identity is computed once here; the structure's constructor
        # captures the value through the closure.
        reduce_identity = self.get_reduce_identity()
        stride_mnl = self.get_stride_mnl()
        name = self.name
        tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
        element_compute = self.element_compute

        class _Argument(ctypes.Structure):
            _fields_ = [
                ("ptr", ctypes.c_void_p),
                ("reduce_identity", dtype2ctype[element_compute]),
                ("dMNL", tuple_type)
            ]

            def __init__(self, kwargs) -> None:
                ptr = kwargs[name]
                self.ptr = ptr
                self.reduce_identity = reduce_identity
                self.dMNL = tuple_type(stride_mnl)

        return _Argument
184
+
185
+
186
class ColumnReductionImpl(ReductionImplBase):
    """
    Reduction implementation selected when the last two strides of the
    store tensor are (1, 0).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match output nodes with a non-reserved name whose store tensor ends
        with strides (1, 0). `problem_size` is not consulted.
        """
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False

        trailing_strides = node.store_tensor.stride[-2:]
        return True if trailing_strides == (1, 0) else False
200
+
201
+
202
class RowReductionImpl(ReductionImplBase):
    """
    Reduction implementation selected when the last two strides of the
    store tensor are (0, 1).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match output nodes with a non-reserved name whose store tensor ends
        with strides (0, 1). `problem_size` is not consulted.
        """
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False

        trailing_strides = node.store_tensor.stride[-2:]
        return True if trailing_strides == (0, 1) else False
216
+
217
+
218
class ScalarReductionImpl(ReductionImplBase):
    """
    Reduction implementation selected when the last two strides of the
    store tensor are (0, 0).
    """

    @staticmethod
    def match(node, problem_size: tuple):
        """
        Match output nodes with a non-reserved name whose store tensor ends
        with strides (0, 0). `problem_size` is not consulted.
        """
        if not node.is_output or node.name in StoreImplBase.reserved_names:
            return False

        trailing_strides = node.store_tensor.stride[-2:]
        return True if trailing_strides == (0, 0) else False
232
+
233
+
234
class StoreNode(NodeBase):
    """
    Store node
    """
    # Candidate implementations, listed in the order they are stored
    possible_impls = [
        AuxStoreImpl,
        RowReductionImpl,
        ColumnReductionImpl,
        ScalarReductionImpl,
        NoOpImpl,
        StoreDImpl,
    ]

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self.op = "store"
        # Backing storage of the `store_tensor` property
        self._store_tensor = None
        self.is_output = False

    @property
    def store_tensor(self) -> Tensor:
        """
        Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
        """
        return self._store_tensor

    @store_tensor.setter
    def store_tensor(self, kwargs):
        """
        Build the store tensor from a dict of `Tensor` keyword arguments.
        """
        self._store_tensor = Tensor(**kwargs)

    def type_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        The store node has element_output = element_input.

        For output nodes, `element` is additionally taken from the store
        tensor; RuntimeError is raised when that tensor is missing.
        """
        if self.is_output:
            store_tensor = self.store_tensor
            if store_tensor is None:
                raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
            self.element = store_tensor.element
        assert len(input_node_metas) == 1, "Store node can only have one input node"
        (only_input,) = input_node_metas
        self.element_output = only_input.element_output

    def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
        """
        Run the base-class propagation, then broadcast the store tensor of an
        output node to this node's tensor shape.
        """
        super().broadcast_propagation(input_node_metas)
        if not self.is_output:
            return
        self._store_tensor.broadcast(self.tensor.shape)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ High-level class for tensor
35
+ """
36
+
37
+ from cutlass_library import LayoutType
38
+
39
+ from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
40
+ Layout,
41
+ broadcast,
42
+ canonicalization,
43
+ permutation,
44
+ reshape,
45
+ _reverse_tuple
46
+ )
47
+ from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
48
+
49
+
50
class Tensor:
    """
    The tensor abstracts the data type

    A tensor is described either by an existing `tensor` object or by the
    triple (element, shape, layout_tag or stride); the two forms are mutually
    exclusive. The layout is stored in reversed dimension order; `shape` and
    `stride` reverse it back.
    """
    def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
        has_tensor = tensor is not None

        # Argument validation; the checks run in a fixed order and the first
        # violated rule decides which message is raised.
        if element is not None and has_tensor:
            raise Exception("Must not specify both element and tensor")
        if shape is not None and has_tensor:
            raise Exception("Must not specify both shape and tensor")
        if layout_tag is not None and has_tensor:
            raise Exception("Must not specify both layout_tag and tensor")
        incomplete = element is None or (layout_tag is None and stride is None) or shape is None
        if incomplete and not has_tensor:
            raise Exception("Must specify one of (element, shape, layout/stride) or (tensor)")
        if stride is not None and has_tensor:
            raise Exception("Must not specify both stride and tensor")
        if stride is not None and layout_tag is not None:
            raise Exception("Must not specify layout_tag when stride is provided")

        if isinstance(tensor, Tensor):
            # Directly copy all the attributes
            self.__dict__.update(vars(tensor))
            return

        if has_tensor:
            # Derive element type, layout tag and shape from the given object
            self.element, layout_tag = get_datatype_and_layout(tensor)
            shape = get_tensor_shape(tensor)
        else:
            self.element = library_type(element)

        if stride is not None:
            self.layout = Layout(shape[::-1], stride[::-1])
        elif layout_tag == LayoutType.RowMajor:
            self.layout = Layout(shape[::-1])
        elif layout_tag == LayoutType.ColumnMajor:
            reversed_order = list(reversed(range(len(shape))))
            self.layout = permutation(Layout(shape), reversed_order)
        self.layout = canonicalization(self.layout)

        self.is_constant = is_constant
        # Save the tensor value if it is constant
        if is_constant and has_tensor:
            self.value = tensor

    @property
    def shape(self):
        """
        Returns the RowMajor layout shape
        """
        return _reverse_tuple(self.layout.shape)

    @property
    def stride(self):
        """
        Returns the RowMajor layout stride
        """
        return _reverse_tuple(self.layout.stride)

    @property
    def rank(self):
        """
        Returns the rank of the tensor
        """
        return len(self.shape)

    #
    # Layout Algorithms
    #

    def broadcast(self, shape):
        """
        Broadcast self.layout to shape
        """
        assert isinstance(shape, tuple)
        reversed_shape = _reverse_tuple(shape)
        self.layout = broadcast(self.layout, reversed_shape)

    def reshape(self, shape):
        """
        Reshape self.layout to shape
        """
        assert isinstance(shape, tuple)
        self.layout = reshape(self.layout, _reverse_tuple(shape))

    def permute(self, indices):
        """
        Permute self.layout according to indices

        Each index is mirrored to `len(indices) - idx - 1` and the resulting
        list is reversed before it is applied to the layout.
        """
        rank = len(indices)
        mirrored = [rank - idx - 1 for idx in indices]
        self.layout = permutation(self.layout, mirrored[::-1])
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
34
+ from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
35
+ from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
36
+ from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
37
+ from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
38
+ from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
39
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
40
+ from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
41
+ from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
42
+ from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ from __future__ import annotations
33
+
34
+ import subprocess
35
+
36
+ from cutlass_library import DataTypeTag
37
+
38
+ from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
39
+
40
+
41
+ _COLOR_MAP = {
42
+ "load": '"AliceBlue"',
43
+ "compute": "LemonChiffon1",
44
+ "accumulator": "LightGrey",
45
+ "store": "PowderBlue",
46
+ "layout": "lightseagreen",
47
+ "dag": "darkorange"
48
+ }
49
+
50
+
51
class EVTGraphDrawer:
    """
    Visualize an EVT DAGIR with graphviz.

    A pydot graph is built for the top-level DAG and, recursively, for the
    subgraph of every node whose op is "dag". All graphs are kept in a dict
    keyed by graph name.
    """

    def __init__(
        self,
        graph: "DAGIR",
        name: str
    ):
        """
        :param graph: the DAG IR to render
        :param name: name of the top-level graph (also its lookup key)
        """
        self._name = name
        self._dot_graphs = {}

        # Nested "dag" subgraphs are registered by _to_dot while it runs, so
        # they are inserted into the dict before the top-level graph.
        self._dot_graphs[name] = self._to_dot(graph, name)

    def _get_node_style(self, node):
        """
        Return the graphviz attribute dict used to draw ``node``.

        :raises NotImplementedError: if ``node.op`` has no entry in ``_COLOR_MAP``
        """
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        if node.op not in _COLOR_MAP:
            raise NotImplementedError("unknown node op")
        template["fillcolor"] = _COLOR_MAP[node.op]
        if node.disabled:
            # Disabled nodes are drawn washed out
            template["fontcolor"] = "grey"
            template["fillcolor"] = "white"
        return template

    def _get_node_label(self, node):
        """
        Build the record-shaped label of ``node``: ``{field|field|...}``.

        The fields are, in order: name and op, layout function and kwargs
        (layout nodes only), implementation class and element types (only once
        ``underlying_impl`` is set), tensor shape/stride, and store tensor
        shape/stride when the node has a ``store_tensor``.
        """
        fields = [f"name={node.name}", f"op={node.op}"]
        if node.op == "layout":
            fields.append(f"fn={node.fn.__name__}")
            for key in node.kwargs:
                fields.append(f"{key}={node.kwargs[key]}")

        impl = node.underlying_impl
        if impl is not None:
            fields.append(f"impl={type(impl).__name__}")
            if node.op == "load":
                fields.append(f"element_output={DataTypeTag[impl.element]}")
            elif node.op == "compute":
                fields.append(f"element_compute={DataTypeTag[impl.element_compute]}")
                fields.append(f"element_output={DataTypeTag[impl.element_output]}")
            elif node.op == "store":
                fields.append(f"element_store={DataTypeTag[impl.element]}")
                fields.append(f"element_output={DataTypeTag[impl.element_output]}")
            elif node.op == "dag":
                fields.append(f"element_output={DataTypeTag[impl.element_output]}")

        if node.tensor is not None:
            fields.append(f"shape={node.tensor.shape}")
            fields.append(f"stride={node.tensor.stride}")

        # Only some node kinds carry a store tensor
        store_tensor = getattr(node, "store_tensor", None)
        if store_tensor is not None:
            fields.append(f"store_shape={store_tensor.shape}")
            # This field used to be mislabeled "stride_stride"
            fields.append(f"store_stride={store_tensor.stride}")

        return "{" + "|".join(fields) + "}"

    def _to_dot(
        self,
        graph: "DAGIR",
        name: str
    ):
        """
        Convert ``graph`` to a pydot graph called ``name``.

        Subgraphs of "dag" nodes are converted recursively and registered in
        ``self._dot_graphs`` under the dag node's name.
        """
        # Imported lazily so pydot is only required when a graph is drawn
        import pydot

        # "rankdir" is the graphviz layout-direction attribute (top to bottom)
        dot_graph = pydot.Dot(name, rankdir="TB")
        for node in graph.nodes_meta:
            style = self._get_node_style(node)
            label = self._get_node_label(node)
            dot_node = pydot.Node(
                node.name, label=label, **style
            )
            dot_graph.add_node(dot_node)
            if node.op == "dag":
                dot_subgraph = self._to_dot(node.subgraph, name=node.name)
                self._dot_graphs[node.name] = dot_subgraph

        # Add edges, labeled with their weight
        for src, dst in graph.edges:
            weight = graph.get_edge_weight(src, dst)
            dot_graph.add_edge(pydot.Edge(src, dst, label=weight))

        return dot_graph

    def get_dot_graph(self) -> "list[tuple[str, pydot.Dot]]":
        """Return ``(name, graph)`` pairs for every graph that was built."""
        return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]

    def get_dot_graph_by_name(self, name) -> "pydot.Dot":
        """Return the graph registered under ``name`` (KeyError if unknown)."""
        return self._dot_graphs[name]

    def get_main_dot_graph(self) -> "pydot.Dot":
        """Return the top-level graph."""
        return self._dot_graphs[self._name]
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Construct the epilogue visitor argument type
35
+ """
36
+
37
+ from cutlass_cppgen.backend.c_types import visitor_factory
38
+ from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
39
+ from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
40
+ from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
41
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
42
+ from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
43
+ from cutlass_cppgen.backend.evt.passes.util import cc_map
44
+
45
+
46
class PassGetArgumentType(EVTPassBase):
    """
    Construct the epilogue visitor argument type
    """
    dependencies = [
        PassShapeTypePropagation,   # Shape and type info of the nodes
        PassDAG2Tree,               # Non-tree regions are fused into DAG (TopoVisitorNode) nodes
        PassGetImpl                 # underlying_impl of each node must be set
    ]

    def requires(self) -> None:
        """
        Check that the graph has an output node "D" on Sm90/Sm100.

        :raises SyntaxError: if the node "D" is missing on those architectures
        """
        if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
            raise SyntaxError(
                "Sm90+ EVT requires the epilogue to have a returned tensor D, "
                "but the variable 'D' is not found in the return values.")

    def call(self):
        """
        Build ``self.argument_types`` (node name -> argument type) in
        topological order, then run the cc-specific finalization.
        """
        nodes = self.dag_ir.nodes_topological_order()
        self.argument_types = {}
        for node in nodes:
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.argument_types[node] = meta.underlying_impl.argument_type
            # On Sm90/Sm100 the node "D" keeps its own argument type; it is
            # not combined with the types of its inputs
            if node == "D" and cc_map[self.cc] in [90, 100]:
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_argument_type(node)
            else:
                self.get_evt_argument_type(node)

        # NOTE(review): cc_specific_method is defined on the base class;
        # presumably it resolves to the sm80_/sm90_/sm100_ variant below
        self.cc_specific_method(self.set_argument_type)()

    def get_evt_argument_type(self, node):
        """
        Combine the argument types of the inputs of ``node`` with the node's
        own type. Leaf nodes (no inputs) keep their own type unchanged.
        """
        # The inputs are used in the order get_all_inputs returns them
        input_names = self.dag_ir.get_all_inputs(node)
        input_types = [self.argument_types[child] for child in input_names]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(
                input_types + [self.argument_types[node],], input_names + [node,])

    def get_dag_argument_type(self, node):
        """
        Build the argument type of a fused DAG node from the nodes of its
        subgraph.
        """
        meta = self.dag_ir.get_node_meta(node)
        subgraph = meta.subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Record the type of every enabled subgraph node; disabled ones are
        # expected to already have an entry from the enclosing graph
        for n in subgraph_nodes:
            m = subgraph.get_node_meta(n)
            if m.disabled:
                continue
            self.argument_types[n] = m.underlying_impl.argument_type
        # Every subgraph node except the last one in topological order
        member_names = subgraph_nodes[:-1]
        input_types = [self.argument_types[child] for child in member_names]
        if len(input_types) > 0:
            self.argument_types[node] = visitor_factory(input_types, member_names)

    def set_argument_type(self):
        """Generic finalization hook; intentionally does nothing."""
        pass

    def sm90_set_argument_type(self):
        """Publish the epilogue, D and C argument types on the DAG IR (Sm90)."""
        # The epilogue thread type is the type of the first input of "D"
        self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
        # Get the tensorD argument type
        self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d

        # Get the tensorC argument type; fall back to D's type without a "C" node
        if self.dag_ir.has_node("C"):
            self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
        else:
            self.dag_ir.arg_c_type = self.dag_ir.arg_d_type

    def sm100_set_argument_type(self):
        """Sm100 uses the same finalization as Sm90."""
        self.sm90_set_argument_type()

    def sm80_set_argument_type(self):
        """On Sm80 the epilogue type is the type of the last node in topological order."""
        nodes = self.dag_ir.nodes_topological_order()
        self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
35
+ by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
36
+ """
37
+
38
+ from copy import deepcopy
39
+
40
+ from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
41
+ from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
42
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
43
+ from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
44
+
45
+
46
class PassDAG2Tree(EVTPassBase):
    """
    Convert the DAG IR to Tree by fusing subgraphs
    """
    dependencies = [
        PassShapeTypePropagation,
        PassGetImpl
    ]

    def call(self):
        """
        Fuse every region hanging off a node with more than one user into a
        single TopoVisitorNode, so that no node keeps an out degree above 1.

        :raises RuntimeError: if a region has no common ancestor and no usable
            output node can be found for it
        """
        # Step 1: find the nodes that have multiple parents (users)
        multi_parent_nodes = [
            node for node in self.dag_ir.nodes_topological_order()
            if self.dag_ir.out_degree(node) > 1
        ]

        # Step 2: find the lowest common ancestor (LCA) of all its parents
        for node in multi_parent_nodes:
            # A multi-parent node could be already fused by the previous node
            if not self.dag_ir.has_node(node):
                continue
            # A node uncovered by the previous fusions can have out degree change
            # Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
            # Case 2: it has more than one edges to the previously fused subgraph, degree drops
            if self.dag_ir.out_degree(node) <= 1:
                continue

            # Otherwise, the node still has multiple users and must be fused
            # Complexity: O(Dout*N)
            reachable_nodes = [
                set(self.dag_ir.all_reachable_nodes(parent))
                for parent in self.dag_ir.get_users(node)
            ]
            # get the common reachable objects
            common_items = set.intersection(*reachable_nodes)
            node_to_fuse = set.union(*reachable_nodes).difference(common_items)

            if len(common_items) > 0:
                # A common ancestor exists: take the one that comes first in
                # topological order
                topo_rank = {
                    name: rank
                    for rank, name in enumerate(self.dag_ir.nodes_topological_order())
                }
                lca = min(common_items, key=topo_rank.__getitem__)
            else:
                # there is no common ancestor for all the parents, we pack all the reachable
                # nodes into a single DAG node as a fallback. The lca should be the input node of
                # one of the output nodes with out_degree = 0
                potential_output_nodes = [
                    candidate for candidate in node_to_fuse
                    if self.dag_ir.out_degree(candidate) == 0
                ]
                if len(potential_output_nodes) == 0:
                    raise RuntimeError("No output node with out degree = 0 found.")

                if self.dag_ir.cc >= 90:
                    # For SM90+, the lca should be the input node of D
                    if not self.dag_ir.has_node("D"):
                        raise RuntimeError("D is not a node in the DAG IR.")
                    output_node = "D"
                else:
                    output_node = potential_output_nodes[0]

                lca = self.dag_ir.get_all_inputs(output_node)[0]
                # The output node itself stays outside the fused region
                node_to_fuse.remove(output_node)

            # The lca is the output node of the DAG node
            # Get the nodes to be fused
            node_to_fuse.add(lca)
            # Get all the input nodes and users of the fused nodes
            all_input_nodes = set.union(
                *(set(self.dag_ir.get_all_inputs(member)) for member in node_to_fuse))
            all_output_nodes = set.union(
                *(set(self.dag_ir.get_users(member)) for member in node_to_fuse))

            new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)

            # Create the subgraph; nodes that only border the fused region
            # are copied in as disabled
            subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
            subgraph = DAGIR(self.dag_ir.cc)
            for member in subgraph_.nodes:
                meta = deepcopy(self.dag_ir.get_node_meta(member))
                if member not in node_to_fuse:
                    meta.disabled = True
                subgraph.add_node(meta)
            for src, dst in subgraph_.edges:
                subgraph.add_edge(src, dst, self.dag_ir.get_edge_weight(src, dst))

            # Create the fused node
            dag_node = TopoVisitorNode(
                name=f"dag_{lca}", subgraph=subgraph,
                output_node=self.dag_ir.get_node_meta(lca))
            self.dag_ir.add_node(dag_node)

            # Add input edges
            for idx, input_node in enumerate(all_input_nodes):
                self.dag_ir.add_edge(input_node, dag_node.name, weight=idx)

            # Replace all uses with DAG node (only 1 output node)
            self.dag_ir.replace_all_uses_with(lca, dag_node.name)

            # Remove all fused nodes
            node_to_fuse.remove(lca)
            for fused in node_to_fuse:
                self.dag_ir.remove_node(fused)

    def ensures(self) -> None:
        """
        Ensure that after the pass, the resulting DAG becomes a tree.

        :raises RuntimeError: if any node still has an out degree above 1
        """
        for node in self.dag_ir.nodes:
            out_degree = self.dag_ir.out_degree(node)
            if out_degree > 1:
                raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Fix the element_output of producer of D.
35
+
36
+ In Sm90 epilogue visitor, the node writing D to gmem does not have internal
37
+ element converter, so the compute node producing D must have element_output = type(D).
38
+ """
39
+
40
+ from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
41
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
42
+
43
+
44
class PassFixElementD(EVTPassBase):
    """
    In Sm90 epilogue visitor, the node writing D to gmem does not have internal
    element converter, so the compute node producing D must have
    element_output = type(D)
    """
    dependencies = [
        PassLayoutManipulateElimination
    ]

    def get_producer(self, node, element_D):
        """
        Walk from ``node`` through the first input of each store node and set
        ``element_output`` of the first compute node reached to ``element_D``.

        The walk stops without changing anything when it reaches a node that
        is neither a compute node nor a store node.
        """
        current = node
        while True:
            node_meta = self.dag_ir.get_node_meta(current)
            if node_meta.op == "compute":
                node_meta.element_output = element_D
                return
            if node_meta.op != "store":
                return
            # Store nodes are looked through via their first input
            current = self.dag_ir.get_all_inputs(current)[0]

    def call(self):
        """Apply the fix when the graph has a node "D"; otherwise do nothing."""
        if self.dag_ir.has_node("D"):
            node_d_meta = self.dag_ir.get_node_meta("D")
            element_D = node_d_meta.store_tensor.element
            self.get_producer("D", element_D)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Infer the underlying implement of each node.
35
+
36
+ While the frontend only distinguish between Load/Store/Compute Node,
37
+ each of these nodes can have different underlying implementation based
38
+ on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
39
+ This pass infers the underlying impl of each node
40
+ """
41
+
42
+ import cutlass_cppgen.backend.evt.backend as evt_backend
43
+ from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
44
+ from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
45
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
46
+ from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
47
+ from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
48
+ from cutlass_cppgen.backend.evt.passes.util import cc_map
49
+
50
+
51
class PassGetImpl(EVTPassBase):
    """
    While the frontend only distinguish between Load/Store/Compute Node,
    each of these nodes can have different underlying implementation based
    on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
    This pass infers the underlying impl of each node
    """
    dependencies = [
        PassShapeTypePropagation,   # The shape and type info are required for inference
        PassFixElementD
    ]

    def __init__(self, dag_ir: DAGIR) -> None:
        """
        :param dag_ir: the DAG IR this pass operates on
        """
        super().__init__(dag_ir)
        # Run from ensures() to remove nodes that were lowered to NoOp
        self.no_op_elimination = PassNoOpElimination(dag_ir)

    def requires(self) -> None:
        """
        Verify "accum" is in the arg list.

        :raises SyntaxError: if the graph has no node named "accum"
        """
        if not self.dag_ir.has_node("accum"):
            raise SyntaxError("Cannot find 'accum' in the argument list.")

    def call(self):
        """Ask every node, in topological order, to pick its underlying impl."""
        # The loop structure of the epilogue is determined by the
        # accumulator shape
        accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
        problem_size = accumulator.tensor.shape

        for node_meta in self.dag_ir.node_metas_topological_order():
            node_meta.get_underlying_impl(problem_size)

    def ensures(self) -> None:
        """
        Eliminate NoOp nodes, then replace each generic impl with the class of
        the same name prefixed by ``Sm<cc>`` from the ``sm<cc>_nodes`` backend
        module.
        """
        # Some nodes will be lowered to NoOp, eliminate them
        self.no_op_elimination()
        # Lower to cc-specific impl; the backend module is the same for every
        # node, so it is looked up once
        arch = cc_map[self.cc]
        node_impl_ccs = getattr(evt_backend, f"sm{arch}_nodes")
        for node_meta in self.dag_ir.nodes_meta:
            impl_cls = getattr(
                node_impl_ccs,
                f"Sm{arch}" + node_meta.underlying_impl.__class__.__name__
            )
            node_meta.underlying_impl = impl_cls(node_meta)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Eliminate layout manipulation nodes
35
+ """
36
+
37
+ from copy import deepcopy
38
+
39
+ from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
40
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
41
+ from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
42
+
43
+
44
+ class PassLayoutManipulateElimination(EVTPassBase):
45
+ """
46
+ Eliminate layout manipulation nodes
47
+ """
48
+ dependencies = [PassShapeTypePropagation]
49
+
50
def __init__(self, dag_ir: DAGIR) -> None:
    """
    :param dag_ir: the DAG IR this pass operates on
    """
    super().__init__(dag_ir)
    # Counter used to give copied layout nodes unique names
    self.copy_cnt = 0
53
+
54
def call(self):
    """
    Eliminate layout nodes one at a time until the worklist is empty.

    For each layout node the propagation direction is determined, the
    matching ``propagate_to_<direction>`` handler is invoked, and then all
    uses of the layout node are redirected to its first input.
    """
    self.layout_nodes_worklist = self.get_all_layout_nodes()
    # The worklist is re-read on every iteration rather than iterated once.
    # NOTE(review): presumably the propagate handlers can add nodes to it
    while self.layout_nodes_worklist:
        node = self.layout_nodes_worklist.pop(0)
        # Step 1: get the propagation direction
        direction = self.get_propagation_direction(node)
        # Reset the scratch list left behind by the direction analysis
        self.visited = []
        propagate = getattr(self, f"propagate_to_{direction}")
        propagate(self.dag_ir.get_node_meta(node), node)
        # Eliminate the current node
        input_node = self.dag_ir.get_all_inputs(node)[0]
        self.dag_ir.replace_all_uses_with(node, input_node)
68
+
69
def get_all_layout_nodes(self):
    """
    Return the names of all LayoutNode nodes, in reverse topological order.
    """
    return [
        node_meta.name
        for node_meta in reversed(self.dag_ir.node_metas_topological_order())
        if isinstance(node_meta, LayoutNode)
    ]
75
+
76
def get_propagation_direction(self, node: str):
    """
    The logic is propagating all layout nodes away from the accumulator node.

    :param node: name of the layout node
    :return: "inputs" if "accum" is influenced only along the user
        direction, "users" if it is influenced only along the input direction
    :raises RuntimeError: if "accum" is influenced in both directions or in
        neither
    """
    # Each traversal collects the nodes it touches in a fresh self.visited
    self.visited = []
    self.get_influenced_users(node)
    nodes_influenced_dir_users = self.visited
    self.visited = []
    self.get_influenced_inputs(node)
    nodes_influenced_dir_inputs = self.visited

    accum_via_users = "accum" in nodes_influenced_dir_users
    accum_via_inputs = "accum" in nodes_influenced_dir_inputs

    if accum_via_users and not accum_via_inputs:
        return "inputs"
    if accum_via_inputs and not accum_via_users:
        return "users"
    raise RuntimeError("Unsolved propagation direction")
93
+
94
+ # Get all influenced nodes if we propagate along the user direction
95
+ def get_influenced_users(self, node: str):
96
+ if node in self.visited:
97
+ return
98
+ self.visited.append(node)
99
+
100
+ users = self.dag_ir.get_users(node)
101
+ for user in users:
102
+ self.get_influenced_users(user)
103
+ user_inputs = []
104
+ for user in users:
105
+ user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
106
+ if len(user_inputs) > 0:
107
+ user_inputs = set.union(*user_inputs)
108
+ user_inputs.remove(node)
109
+ for input in user_inputs:
110
+ self.get_influenced_inputs(input)
111
+
112
+ # Get all influenced nodes if we propagate along the input direction
113
+ def get_influenced_inputs(self, node: str):
114
+ if node in self.visited:
115
+ return
116
+ self.visited.append(node)
117
+
118
+ inputs = self.dag_ir.get_all_inputs(node)
119
+ for input in inputs:
120
+ self.get_influenced_inputs(input)
121
+ input_users = []
122
+ for input in inputs:
123
+ input_users.append(set(self.dag_ir.get_users(input)))
124
+ if len(input_users) > 0:
125
+ input_users = set.union(*input_users)
126
+ input_users.remove(node)
127
+ for user in input_users:
128
+ self.get_influenced_users(user)
129
+
130
+ def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
131
+ copied_node_meta = deepcopy(layout_node_meta)
132
+ copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
133
+ self.copy_cnt += 1
134
+ copied_node_meta.name = copied_node
135
+ self.dag_ir.add_node(copied_node_meta)
136
+ # Add edges
137
+ target_inputs = self.dag_ir.get_all_inputs(target)
138
+ for src in target_inputs:
139
+ self.dag_ir.remove_edge(src, target)
140
+ self.dag_ir.add_edge(src, copied_node)
141
+ self.dag_ir.add_edge(copied_node, target)
142
+ self.layout_nodes_worklist.append(copied_node)
143
+
144
+ def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
145
+ copied_node_meta = deepcopy(layout_node_meta)
146
+ copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
147
+ self.copy_cnt += 1
148
+ copied_node_meta.name = copied_node
149
+ self.dag_ir.add_node(copied_node_meta)
150
+ # Add edges
151
+ users = self.dag_ir.get_users(target)
152
+ for user in users:
153
+ self.dag_ir.remove_edge(target, user)
154
+ self.dag_ir.add_edge(copied_node, user)
155
+ self.dag_ir.add_edge(target, copied_node)
156
+ self.layout_nodes_worklist.append(copied_node)
157
+
158
+ # Propagate the layout `node` along the user direction
159
+ def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
160
+ """
161
+ Propagate layout node to users
162
+ """
163
+ if node in self.visited:
164
+ # Avoid applying twice
165
+ return
166
+ self.visited.append(node)
167
+
168
+ node_meta = self.dag_ir.get_node_meta(node)
169
+ if layout_node_meta.name != node:
170
+ if isinstance(node_meta, LayoutNode):
171
+ # Layout node is not transparent with layout node
172
+ self.add_copy_before(layout_node_meta, node)
173
+ return
174
+ else:
175
+ layout_node_meta.apply_to_user(node_meta)
176
+
177
+ users = self.dag_ir.get_users(node)
178
+ user_inputs = []
179
+ for user in users:
180
+ user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
181
+ for user in users:
182
+ self.propagate_to_users(layout_node_meta, user)
183
+ if len(user_inputs) > 0:
184
+ user_inputs = set.union(*user_inputs)
185
+ user_inputs.remove(node)
186
+ for input in user_inputs:
187
+ self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
188
+
189
+ # Propagate the layout `node` along the input direction
190
+ def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
191
+ """
192
+ Propagate layout node to inputs
193
+ """
194
+ if node in self.visited:
195
+ # Avoid applying twice
196
+ return
197
+ self.visited.append(node)
198
+
199
+ node_meta = self.dag_ir.get_node_meta(node)
200
+ if layout_node_meta.name != node:
201
+ if isinstance(node_meta, LayoutNode):
202
+ # Layout node is not transparent with layout node
203
+ self.add_copy_after(layout_node_meta, node)
204
+ return
205
+ else:
206
+ layout_node_meta.apply_to_input(node_meta)
207
+ inputs = self.dag_ir.get_all_inputs(node)
208
+ input_users = []
209
+ for input in inputs:
210
+ input_users.append(set(self.dag_ir.get_users(input)))
211
+ for input in inputs:
212
+ self.propagate_to_inputs(layout_node_meta, input)
213
+ if len(input_users) > 0:
214
+ input_users = set.union(*input_users)
215
+ input_users.remove(node)
216
+ for user in input_users:
217
+ self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Pass manager for DAG IR.
35
+ """
36
+
37
+ from typing import Any
38
+
39
+ import networkx as nx
40
+
41
+ from cutlass_cppgen.backend.evt.ir import DAGIR
42
+ from cutlass_cppgen.backend.evt.passes.util import cc_map
43
+
44
+
45
class EVTPassBase:
    """
    Base class for EVT Passes

    A pass is constructed from a DAG IR and executed by calling the instance,
    which runs `requires()`, `call()` and `ensures()` in that order.
    Subclasses override `call()` and may list the pass classes they depend on
    in `dependencies`.
    """
    # Pass classes this pass depends on; overridden by subclasses
    dependencies = []

    def __init__(self, dag_ir: "DAGIR") -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    def requires(self) -> None:
        """
        This function will be called before the pass is run.
        """
        pass

    def call(self) -> None:
        """
        The pass that is run through the self.dag_ir

        :raises NotImplementedError: if the subclass does not override it
        """
        # The message names `call`, the method a subclass has to override
        raise NotImplementedError(
            f"call is not overwritten in Pass {self.__class__.__name__}")

    def ensures(self) -> None:
        """
        This function will be called after the pass is run.
        """
        pass

    def __call__(self) -> Any:
        """
        Run the pass: `requires()`, then `call()`, then `ensures()`.
        """
        self.requires()
        self.call()
        self.ensures()

    def cc_specific_method(self, func):
        """
        This enables defining function that behaves differently under different cc
        The simplest example of using this function is the following

        .. highlight:: python
        .. code-block:: python

            class ExamplePass(EVTPassBase):

                def call(self):
                    # This automatically selects the smXX_func based on current cc
                    self.cc_specific_method(self.func)()

                # Interface func, can be empty
                def func(self):
                    pass

                # Sm90 specific func
                def sm90_func(self):
                    # sm90 specific method
                    return

                # Sm80 specific func
                def sm80_func(self):
                    # sm80 specific method
                    return

        :param func: the interface function whose name is used for the lookup
        :return: the attribute named ``sm{cc_map[self.cc]}_{func.__name__}``
        :raises NotImplementedError: if this pass has no such attribute
        """
        func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
        if not hasattr(self, func_name):
            raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
        return getattr(self, func_name)
111
+
112
+
113
class EVTPassManager(nx.DiGraph):
    """
    Topological-based Pass Manager.
    Each registered pass has a list of dependencies. The pass manager organizes
    the passes as a DAG and launch the compiler passes under topological order.
    """
    def __init__(self, dag_ir: DAGIR, pass_list):
        """
        :param dag_ir: the DAG IR handed to every pass
        :param pass_list: the pass classes to register
        """
        super().__init__()
        self.dag_ir = dag_ir
        for pass_cls in pass_list:
            self.add_pass(pass_cls)

        self.sorted_passes = self.schedule()

    def get_callable(self, pass_name):
        """
        Return the callable of the pass
        """
        return self.nodes[pass_name]["callable"]

    def add_pass(self, pass_cls):
        """
        Add a pass to the pass manager
        :param pass_cls: the class of pass
        :type pass_cls: derived class of EVTPassBase
        """
        name = pass_cls.__name__
        pass_callable = pass_cls(self.dag_ir)
        self.add_node(name, callable=pass_callable)

    def schedule(self):
        """
        Schedule the added passes under topological order

        :return: the list of pass names in topological order
        :raises RuntimeError: if a pass depends on a pass that was not added
        """
        # Add edges. Iterate over a snapshot of the node names so that the
        # graph is never mutated while its node view is being iterated.
        for pass_name in list(self.nodes):
            pass_obj = self.get_callable(pass_name)
            for dependency_cls in pass_obj.dependencies:
                dependency_name = dependency_cls.__name__
                if dependency_name not in self.nodes:
                    # add_edge would silently create a node without "callable"
                    raise RuntimeError(
                        f"Pass {pass_name} depends on {dependency_name}, "
                        f"which is not registered in the pass manager")
                self.add_edge(
                    dependency_name,
                    type(pass_obj).__name__)

        # Topological sort
        return list(nx.topological_sort(self))

    def __call__(self) -> Any:
        """
        Launch the registered passes
        """
        for pass_name in self.sorted_passes:
            pass_obj = self.get_callable(pass_name)
            pass_obj()
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ No-op node elimination pass
35
+ """
36
+
37
+ from typing import Any
38
+
39
+ from cutlass_cppgen.backend.evt.ir import NoOpImpl
40
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
41
+
42
+
43
class PassNoOpElimination(EVTPassBase):
    """
    The no-op elimination pass: every node of the DAG IR whose underlying
    implementation is a NoOpImpl has all its uses replaced with its first input
    """
    dependencies = []

    def call(self) -> Any:
        dag = self.dag_ir
        for name in dag.nodes_topological_order():
            impl = dag.get_node_meta(name).underlying_impl
            if not isinstance(impl, NoOpImpl):
                continue
            # Redirect the uses of the no-op node to its first input
            dag.replace_all_uses_with(name, dag.get_all_inputs(name)[0])
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Preprocess the reduction nodes.
35
+
36
+ The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
37
+ This pass fuses these into a single store node, and then replaces all uses of the
38
+ current node with the new store node.
39
+ """
40
+
41
+ from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
42
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
43
+
44
+
45
class PassPreprocessRed(EVTPassBase):
    """
    Preprocess red nodes

    A reduction is parsed as a compute node whose `fn` is a tuple
    (reg_reduce_fn, gmem_reduce_fn), followed by a store node. This pass moves
    the reduction functions onto the store node, rewires the DAG so that the
    store node takes the place of the compute node, and records the store
    node in `dag_ir.reduction_names`.
    """

    def call(self):
        # Step 1: find the compute nodes with op=red
        # To keep the frontend simple, the reduction nodes
        # are parsed into compute nodes by default
        # The simple heuristic to distinguish between compute
        # and reduction node is that compute node is a single function,
        # while the reduction node is a tuple of functions for
        # in-register reduction and atomic global memory reduction
        red_compute_nodes = [
            node_meta.name
            for node_meta in self.dag_ir.nodes_meta
            if isinstance(node_meta, ComputeNode) and isinstance(node_meta.fn, tuple)
        ]

        # Step 2: for each compute, merge it with the succeeding store
        for node in red_compute_nodes:
            # Verify
            users = self.dag_ir.get_users(node)
            inputs = self.dag_ir.get_all_inputs(node)
            # Has a single user and a single input
            assert len(users) == 1, \
                f"Reduction node {node} must have exactly one user, got {len(users)}"
            assert len(inputs) == 1, \
                f"Reduction node {node} must have exactly one input, got {len(inputs)}"
            user = users[0]
            src = inputs[0]

            user_meta = self.dag_ir.get_node_meta(user)
            # Must be a store node
            assert isinstance(user_meta, StoreNode), \
                f"The user {user} of reduction node {node} must be a store node"
            # With output degree == 0
            assert self.dag_ir.out_degree(user) == 0, \
                f"The store node {user} of reduction node {node} must not have users"
            # Register the reduce op
            node_meta = self.dag_ir.get_node_meta(node)
            user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
            user_meta.element_compute = node_meta.element_compute
            user_meta.round_style = node_meta.round_style

            # Replace all uses: the remaining users of the input are fed by
            # the store node instead, keeping their edge weights
            self.dag_ir.remove_edge(src, node)
            for other_user in self.dag_ir.get_users(src):
                weight = self.dag_ir.get_edge_weight(src, other_user)
                self.dag_ir.add_edge(user, other_user, weight)
                self.dag_ir.remove_edge(src, other_user)
            self.dag_ir.add_edge(src, user)
            self.dag_ir.remove_node(node)

            # Register the reduction name
            self.dag_ir.reduction_names.append(user)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Shape and type propagation pass
35
+ """
36
+
37
+ from cutlass_cppgen.backend.evt.ir.node import NodeBase
38
+ from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
39
+ from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
40
+
41
+
42
class PassShapeTypePropagation(EVTPassBase):
    """
    Propagate the shape and type of all nodes
    """
    dependencies = [PassPreprocessRed]

    def call(self):
        dag = self.dag_ir

        # Sweep in topological order: type propagation, then shape propagation
        for name in dag.nodes_topological_order():
            meta: NodeBase = dag.get_node_meta(name)
            producers = dag.get_all_inputs_meta(name)
            meta.type_propagation(producers)
            meta.shape_propagation(producers)

        # Sweep in reversed topological order: broadcast propagation
        for name in reversed(dag.nodes_topological_order()):
            meta: NodeBase = dag.get_node_meta(name)
            producers = dag.get_all_inputs_meta(name)
            meta.broadcast_propagation(producers)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Compute the shared memory size in bytes
35
+ """
36
+
37
+ from math import gcd
38
+
39
+ import cutlass_library
40
+ from pycute import flatten, shape_div, product
41
+
42
+ import cutlass_cppgen
43
+ from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
44
+ from cutlass_cppgen.backend.library import DataType, DataTypeSize
45
+
46
+
47
class GetSmemSize:
    """
    Get the size in byte of shared memory used by the kernel

    Calling the instance with a tile description dispatches to
    ``sm{cc}_epilogue_smem_size`` according to `dag_ir.cc`.
    """
    def __init__(self, dag_ir: "DAGIR") -> None:
        self.dag_ir = dag_ir
        self.cc = self.dag_ir.cc

    #
    # Shared helpers of the epilogue tile computation
    #

    def _source_element(self):
        """
        Return the element type of node "C", or None if there is no such node.
        """
        if self.dag_ir.has_node("C"):
            return self.dag_ir.get_node_meta("C").element
        return None

    def _record_epilogue_tile(self, tile_description, epilogue_tile_mn, epi_tiles,
                              stages_c, stages_d, reuse_smem_c,
                              element_c, element_d, is_source_supported):
        """
        Store the epilogue tile configuration read by the smem size computation.
        """
        self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
        self.epilogue_tile_mn = epilogue_tile_mn
        self.epi_tiles = epi_tiles
        self.stages_c = stages_c
        self.stages_d = stages_d
        self.reuse_smem_c = reuse_smem_c
        self.element_c = element_c
        self.element_d = element_d
        self.is_source_supported = is_source_supported

    def _impl_smem_size(self, meta):
        """
        Query the underlying implementation of `meta` for its smem size entry.
        """
        return meta.underlying_impl.get_smem_size(
            self.cta_tile_mnk, self.epilogue_tile_mn,
            self.stages_c, self.stages_d, self.epi_tiles)

    #
    # Sm90 epilogue specific
    #

    def sm90_epilogue_tile(self, tile_description):
        """
        Compute and record the sm90 epilogue tile and pipeline stages.

        :raises NotImplementedError: if the epilogue schedule is not supported
        """
        # Get the epilogue tile size
        schedule = tile_description.epilogue_schedule
        tb_shape = tile_description.threadblock_shape
        element_d = self.dag_ir.get_node_meta("D").element
        if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
            nperf = 64 if (DataTypeSize[element_d] == 8 and tb_shape[1] % 64 == 0) else 32
            epi_tile_m = min(64, tb_shape[0])
            epi_tile_n = gcd(min(nperf, tb_shape[1]), tb_shape[1])
        elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
            epi_tile_m = min(128, tb_shape[0])
            epi_tile_n = gcd(min(32, tb_shape[1]), tb_shape[1])
        else:
            raise NotImplementedError(f"Unsupported schedule: {schedule}")
        epilogue_tile_mn = (epi_tile_m, epi_tile_n)

        # Get the pipeline stages
        stages_d = 2
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
        element_c = self._source_element()
        reuse_smem_c = element_c == element_d
        stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles

        # Record the epilogue tile
        self._record_epilogue_tile(
            tile_description, epilogue_tile_mn, epi_tiles,
            stages_c, stages_d, reuse_smem_c,
            element_c, element_d, element_c is not None)

    def sm90_or_sm100_epilogue_smem_size(self, tile_description):
        """
        Compute the epilogue shared memory size in bytes from the recorded
        epilogue tile. One of the ``*_epilogue_tile`` methods must run first.

        :param tile_description: not read here; kept for a uniform signature
        """
        # Get the Fusion Storage
        self.smem_types = {}
        for node in self.dag_ir.nodes_topological_order():
            meta = self.dag_ir.get_node_meta(node)
            if not meta.disabled:
                self.smem_types[node] = self._impl_smem_size(meta)
            if node == "D":
                continue
            if isinstance(meta, TopoVisitorNode):
                self.get_dag_smem_type(node)
            else:
                self.get_evt_smem_type(node)

        thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
        # Get the Tensor Storage as (size, alignment) members.
        # DataTypeSize is in bits, hence the division by 8.
        tile_elements = product(self.epilogue_tile_mn)
        tensors = []
        if self.is_source_supported:
            smem_C = DataTypeSize[self.element_c] * tile_elements * self.stages_c // 8
            tensors.append((smem_C, 128))
        else:
            tensors.append((0, 1))
        if self.reuse_smem_c:
            tensors.append((0, 128))
        else:
            smem_D = DataTypeSize[self.element_d] * tile_elements * self.stages_d // 8
            tensors.append((smem_D, 128))
        tensors.append((thread_smem_size, 128))

        tensor_smem_size = self.get_struct_size(tensors)
        # Get pipeline storage size
        # sizeof(uint64_t * stages_c * 2), alignment of uint64_t
        # 2 is for FullBarrier and EmptyBarrier
        pipeline_smem_size = (8 * self.stages_c * 2, 8)

        # get SharedStorage size
        smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
        return smem_size[0]

    def sm90_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of sm90 collective epilogue
        """
        self.sm90_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)

    #
    # Sm100 epilogue specific
    #

    def sm100_epilogue_tile(self, tile_description):
        """
        Compute and record the sm100 epilogue tile and pipeline stages.
        """
        mma_tile = (
            tile_description.blackwell_threadblock_shape[0],
            tile_description.blackwell_threadblock_shape[1],
        )
        is_2sm = tile_description.is_2sm
        # With 2 SMs, the M extent of the tile is halved per CTA
        cta_tile = (mma_tile[0] // 2, mma_tile[1]) if is_2sm else mma_tile
        tmem_warps = (2, 2) if (is_2sm and mma_tile[0] == 128) else (4, 1)

        element_c = self._source_element()
        element_c_size = 0 if element_c is None else DataTypeSize[element_c]
        element_d = self.dag_ir.get_node_meta("D").element
        element_d_size = DataTypeSize[element_d]

        # The source is disabled when there is no "C" node or its element is void
        disable_source = element_c is None or element_c == DataType.void

        cta_m, cta_n = cta_tile
        warp_m, warp_n = tmem_warps
        max_bits = max(element_c_size, element_d_size)
        dp_full = 32
        tile_m = min(cta_m, dp_full * warp_m)

        if disable_source:
            # Epilogues w/o residual load are less sensitive to smem allocation
            # Target a fixed amount of compute per epilogue iteration
            # Make epilogue tile larger to reduce the epilogue iterations.
            # 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
            compute_elts = 8192 if max_bits == 4 else 4096
            n_perf = compute_elts // tile_m
        elif max_bits == 32:
            # Epilogues w/ residual load are more sensitive to smem allocation
            # Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
            n_perf = 16 if cta_m > 64 and cta_n <= 128 else 32
        elif max_bits == 16:
            n_perf = 32 if cta_n <= 128 else 64
        else:
            n_perf = 64

        def is_m_major(layout):
            # NOTE(review): if pycute `flatten` always returns a tuple, this
            # comparison with the int 1 is never true - confirm against pycute
            return flatten(layout.stride[0]) == 1

        def n_min(m_major, element_size):
            # Minimum N extent of the epilogue tile for one tensor
            if m_major:
                return 8 * warp_n
            if element_size == 6:
                return 128 * warp_n
            return (128 // element_size) * warp_n

        # `or` short-circuits, so node "C" is only read when the source is enabled
        n_min_c = n_min(
            disable_source or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout),
            element_c_size)
        n_min_d = n_min(
            is_m_major(self.dag_ir.get_node_meta("D").tensor.layout),
            element_d_size)

        tile_n = min(cta_n, max(n_perf, n_min_c, n_min_d))
        # Round the N extent down to a multiple of the warp count along N
        tile_n = tile_n // warp_n * warp_n

        epilogue_tile_mn = (tile_m, tile_n)
        epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))

        stages_d = min(epi_tiles, 2)
        reuse_smem_c = element_c_size > 8
        stages_c = min(epi_tiles, 4)
        if reuse_smem_c:
            stages_c = max(stages_c, stages_d + 1)

        # Record the epilogue tile
        self._record_epilogue_tile(
            tile_description, epilogue_tile_mn, epi_tiles,
            stages_c, stages_d, reuse_smem_c,
            element_c, element_d, not disable_source)

    def sm100_epilogue_smem_size(self, tile_description):
        """
        Compute the shared memory size of sm100 collective epilogue
        """
        self.sm100_epilogue_tile(tile_description)
        return self.sm90_or_sm100_epilogue_smem_size(tile_description)

    def __call__(self, tile_description):
        """
        Dispatch to the ``sm{cc}_epilogue_smem_size`` method of this object.
        """
        return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)

    #
    # Helper functions
    #

    @staticmethod
    def get_visitor_size(members: list, ebo: bool):
        """
        Get the size and alignment of a struct in bytes

        :param members: list of (size, alignment) pairs, one per member
        :param ebo: when False, a zero-size member still occupies one byte
            (presumably "empty base optimization" - TODO confirm)
        :return: tuple (size, alignment) of the struct
        """
        if not members:
            # Struct size is at least 1
            return (1, 1)

        # The struct alignment is the largest member alignment
        max_alignment = 1
        for _, alignment in members:
            max_alignment = max(max_alignment, alignment)

        def align_up(value):
            return ((value + max_alignment - 1) // max_alignment) * max_alignment

        offset = 0
        for type_size, _ in members:
            if type_size != 0:
                # Non-empty members start at a multiple of the struct alignment
                offset = align_up(offset)
            if type_size == 0 and not ebo:
                offset += 1
            else:
                offset += type_size
        # Pad the tail so that the size is a multiple of the alignment
        return (align_up(offset), max_alignment)

    def get_struct_size(self, members: list):
        """
        Get the size and alignment of a struct in bytes, where zero-size
        members occupy one byte (`get_visitor_size` with ebo=False)
        """
        return self.get_visitor_size(members, False)

    def get_evt_smem_type(self, node):
        """
        Combine the smem entries of the inputs of `node` and of `node` itself
        into a single entry stored in `self.smem_types[node]`.
        """
        # The inputs are taken in the order returned by get_all_inputs;
        # no sorting is done here. The entry of the node itself comes last.
        input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
        input_types.append(self.smem_types[node])
        if len(input_types) > 1:
            ebo = len(input_types) > 4
            self.smem_types[node] = self.get_visitor_size(input_types, ebo)

    def get_dag_smem_type(self, node):
        """
        Compute the smem entry of a node holding a subgraph from the entries
        of all subgraph nodes except the last one in topological order.
        """
        subgraph = self.dag_ir.get_node_meta(node).subgraph
        subgraph_nodes = subgraph.nodes_topological_order()
        # Visit the unvisited nodes in subgraph; disabled nodes are skipped
        for name in subgraph_nodes:
            sub_meta = subgraph.get_node_meta(name)
            if not sub_meta.disabled:
                self.smem_types[name] = self._impl_smem_size(sub_meta)
        input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
        if len(input_types) > 0:
            ebo = len(input_types) > 4
            self.smem_types[node] = self.get_visitor_size(input_types, ebo)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for passes
35
+ """
36
+
37
# Map from the compute capability (CC) of the kernel to the EVT implementation
# that the CC targets. Several CCs share one implementation.
cc_map = {
    **dict.fromkeys((80, 86, 89), 80),
    90: 90,
    **dict.fromkeys((100, 101, 103), 100),
}
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ from __future__ import annotations
33
+
34
+ from cutlass_cppgen.utils.lazy_import import lazy_import
35
+ cuda = lazy_import("cuda.cuda")
36
+ import numpy as np
37
+
38
+ from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
39
+ from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
40
+
41
+
42
class NumpyFrontend:
    """
    Frontend node for numpy
    """

    @staticmethod
    def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
        """Move a numpy array to the device, or reserve device space for an output.

        :param np_tensor: input numpy nd array
        :param is_output: whether the tensor is output; outputs only get an
            uninitialized allocation of the same byte size, inputs are copied

        :return: the result of ``device_mem_alloc`` (outputs) or ``todevice``
            (inputs)
        """
        if not is_output:
            # Inputs: copy the host data to the device.
            return todevice(np_tensor)

        # Outputs: only the storage is needed, sized to hold every element.
        num_bytes = np_tensor.size * np_tensor.itemsize
        return device_mem_alloc(num_bytes)
61
+
62
+
63
class TorchFrontend:
    """
    Frontend node for torch
    """

    @staticmethod
    def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
        """Convert the input torch tensor to CUDA device pointer

        A tensor that is not already on a CUDA device is first copied to
        ``"cuda"``; the returned pointer then refers to that copy.

        :param torch_tensor: input torch tensor

        :return: CUDA device pointer
        """
        # NOTE(review): when a copy is made, this function is the only holder of
        # the copied tensor, so the pointer may outlive its storage -- confirm
        # that callers always pass CUDA tensors or otherwise keep the data alive.
        device_tensor = torch_tensor if torch_tensor.is_cuda else torch_tensor.to("cuda")
        return cuda.CUdeviceptr(device_tensor.data_ptr())
83
+
84
+
85
class CupyFrontend:
    """
    Frontend node for cupy
    """

    @staticmethod
    def argument(cupy_ndarray: "cp.ndarray"):
        """Wrap the device address of a cupy array in a ``cuda.CUdeviceptr``.

        :param cupy_ndarray: input cupy nd array

        :return: CUDA device pointer built from ``cupy_ndarray.data.ptr``
        """
        device_address = int(cupy_ndarray.data.ptr)
        return cuda.CUdeviceptr(device_address)
93
+
94
+
95
class TensorFrontend:
    """
    Universal frontend for client-provided tensors
    """

    @staticmethod
    def argument(tensor, is_output=False):
        """Dispatch ``tensor`` to the frontend matching its library.

        :param tensor: numpy, torch or cupy tensor
        :param is_output: whether the tensor is an output; only forwarded to
            the numpy frontend

        :return: the value produced by the selected frontend
        :raises NotImplementedError: if the tensor type is not recognized
        """
        if is_numpy_tensor(tensor):
            return NumpyFrontend.argument(tensor, is_output)
        if is_torch_tensor(tensor):
            return TorchFrontend.argument(tensor)
        if is_cupy_tensor(tensor):
            return CupyFrontend.argument(tensor)
        raise NotImplementedError("Unknown Tensor Type")
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py ADDED
@@ -0,0 +1,2145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+ from __future__ import annotations
33
+
34
+ import copy
35
+ import ctypes
36
+ import enum
37
+
38
+ from cutlass_cppgen.utils.lazy_import import lazy_import
39
+ cuda = lazy_import("cuda.cuda")
40
+ cudart = lazy_import("cuda.cudart")
41
+ from cutlass_library import SubstituteTemplate
42
+ import numpy as np
43
+
44
+ from cutlass_library import (
45
+ ComplexTransformTag,
46
+ DataType,
47
+ DataTypeNames,
48
+ DataTypeSize,
49
+ DataTypeTag,
50
+ EpilogueScheduleSuffixes,
51
+ EpilogueScheduleTag,
52
+ EpilogueScheduleType,
53
+ GemmKind,
54
+ GemmKindNames,
55
+ GemmUniversalMode,
56
+ KernelScheduleSuffixes,
57
+ KernelScheduleTag,
58
+ KernelScheduleType,
59
+ LayoutTag,
60
+ LayoutType,
61
+ MathOperation,
62
+ MathOperationTag,
63
+ OpcodeClass,
64
+ OpcodeClassNames,
65
+ OpcodeClassTag,
66
+ OperationKind,
67
+ ShortComplexLayoutNames,
68
+ ShortDataTypeNames,
69
+ ShortLayoutTypeNames,
70
+ SwizzlingFunctor,
71
+ SwizzlingFunctorTag,
72
+ TileSchedulerSuffixes,
73
+ TileSchedulerTag,
74
+ TileSchedulerType,
75
+ get_complex_from_real
76
+ )
77
+ from cutlass_cppgen.backend.arguments import ArgumentBase
78
+ from cutlass_cppgen.backend.c_types import (
79
+ GemmCoord_,
80
+ GemmCoordBatched_,
81
+ GenericMainloopArguments3x_,
82
+ StrideBatched_,
83
+ dim3_,
84
+ get_gemm_arguments,
85
+ get_gemm_arguments_3x,
86
+ get_gemm_arguments_streamk,
87
+ get_gemm_grouped_arguments,
88
+ get_mainloop_arguments_3x,
89
+ get_tile_scheduler_arguments_3x,
90
+ )
91
+ from cutlass_cppgen.backend.library import (
92
+ ApiVersion,
93
+ EmissionType,
94
+ SchedulerMode,
95
+ SchedulerModeTag,
96
+ TensorDescription,
97
+ TileDescription,
98
+ api_version,
99
+ )
100
+ from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
101
+ from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
102
+ from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor
103
+ from cutlass_cppgen.backend.utils.device import device_sm_count
104
+ from cutlass_cppgen.shape import GemmCoord, MatrixCoord
105
+
106
+
107
+ ################################################################################
108
+ #
109
+ # Data structure modeling a GEMM operation
110
+ #
111
+ ################################################################################
112
+
113
+
114
def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int:
    """
    Returns the leading dimension of a tensor with layout ``layout`` and shape ``shape``.

    :param layout: layout of the tensor
    :type layout: cutlass_library.LayoutType
    :param shape: shape of the tensor
    :type shape: cutlass_cppgen.shape.MatrixCoord

    :return: ``shape.column`` for row-major tensors, ``shape.row`` for
        column-major tensors, and ``None`` for any other layout
    :rtype: int
    """
    if layout == LayoutType.RowMajor:
        return shape.column
    if layout == LayoutType.ColumnMajor:
        return shape.row
    # Other layouts are not handled here; callers receive None.
    return None
130
+
131
+
132
def transpose_layout(layout: LayoutType) -> LayoutType:
    """
    Returns the layout obtained by transposing a tensor with layout ``layout``.

    :param layout: layout of the tensor
    :type layout: cutlass_library.LayoutType

    :return: row-major for a column-major input and vice versa
    :rtype: cutlass_library.LayoutType

    :raises ValueError: if ``layout`` is neither row-major nor column-major
    """
    swaps = (
        (LayoutType.ColumnMajor, LayoutType.RowMajor),
        (LayoutType.RowMajor, LayoutType.ColumnMajor),
    )
    for source, target in swaps:
        if layout == source:
            return target
    raise ValueError(f"Unsupported Layout {layout}")
139
+
140
+
141
class GemmArguments2x(ArgumentBase):
    """
    Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and
    user-provided tensors into the kernel's arguments.

    :param operation: the GEMM operation to take the argument
    :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
     :class:`cutlass_cppgen.backend.GemmOperationGrouped`

    :param problem_size: GEMM problem size gemm(M, N, K)
    :type problem_size: :class:`cutlass_cppgen.shape.GemmCoord`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param gemm_mode: GEMM mode
    :type gemm_mode: :class:`cutlass_library.GemmUniversalMode`

    :param output_op: output operator, optional
    :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`

    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`

    Further keyword arguments read by the constructor: ``split_k_slices``
    (Gemm / GemmSplitKParallel modes), ``batch`` (Batched / Array modes) and
    ``batch_strides`` (a mapping with keys ``"A"``, ``"B"``, ``"C"``, ``"D"``).
    """

    def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
        self.operation = operation

        self.layout_A = operation.A.layout
        self.layout_B = operation.B.layout
        self.layout_C = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]:
            raise Exception("Interleaved layout not currently supported")

        # With a visitor epilogue on architectures other than 90/100/101/103,
        # C and D are not handed to the base class.
        if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]:
            super().__init__(A, B, None, None, **kwargs)
        else:
            super().__init__(A, B, C, D, **kwargs)

        if operation.switched:
            # The operation swaps the roles of A and B: exchange M/N and the pointers.
            self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
            self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A
        else:
            self.problem_size = problem_size

        # If the number of elements in C = problem_size.n, C is treated as the bias
        if hasattr(self, "tensor_c_numel"):
            if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1:
                self.bias = True

        self.lda = leading_dimension(self.layout_A, self.problem_size.mk)
        self.ldb = leading_dimension(self.layout_B, self.problem_size.kn)
        self.ldc = leading_dimension(self.layout_C, self.problem_size.mn)
        self.ldd = self.ldc

        # A bias vector is re-read for every row, hence a zero leading dimension.
        if self.bias:
            self.ldc = 0

        if "output_op" in kwargs and gemm_mode != GemmUniversalMode.GemmSplitKParallel:
            self.output_op = kwargs["output_op"]
        else:
            # Default epilogue arguments (1, 0), typed to match the epilogue element.
            if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
                dtype = int
            else:
                dtype = float
            self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0))

        self.gemm_mode = gemm_mode
        if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
            # In these modes the batch count carries the number of split-K slices.
            if "split_k_slices" in kwargs:
                self.batch_count = kwargs["split_k_slices"]
            else:
                self.batch_count = 1
            self.split_k_slices = self.batch_count

        if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]:
            if "batch" in kwargs:
                self.batch_count = kwargs["batch"]
            else:
                self.batch_count = 1

        if "batch_strides" in kwargs:
            self.batched_stride_A = kwargs["batch_strides"]["A"]
            self.batched_stride_B = kwargs["batch_strides"]["B"]
            self.batched_stride_C = kwargs["batch_strides"]["C"]
            self.batched_stride_D = kwargs["batch_strides"]["D"]
        else:
            # Default: densely packed batches.
            self.batched_stride_A = self.problem_size.m * self.problem_size.k
            self.batched_stride_B = self.problem_size.n * self.problem_size.k
            self.batched_stride_C = self.problem_size.m * self.problem_size.n
            self.batched_stride_D = self.problem_size.m * self.problem_size.n

        if self.bias:
            self.batched_stride_C = self.problem_size.n

        if gemm_mode == GemmUniversalMode.Array:
            # Array mode takes one pointer per batch entry; build the pointer
            # tables on the host and upload them.
            self.ptr_A_array = []
            self.ptr_B_array = []
            self.ptr_C_array = []
            self.ptr_D_array = []

            ptr_A_addr = int(self.ptr_A)
            ptr_B_addr = int(self.ptr_B)
            ptr_C_addr = int(self.ptr_C)
            ptr_D_addr = int(self.ptr_D)

            # DataTypeSize is divided by 8 to turn element strides into byte strides.
            stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
            stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
            stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
            stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
            for _ in range(self.batch_count):
                self.ptr_A_array.append(ptr_A_addr)
                self.ptr_B_array.append(ptr_B_addr)
                self.ptr_C_array.append(ptr_C_addr)
                self.ptr_D_array.append(ptr_D_addr)

                ptr_A_addr += stride_A
                ptr_B_addr += stride_B
                ptr_C_addr += stride_C
                ptr_D_addr += stride_D

            self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
            self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
            self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
            self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)

        if isinstance(self.operation, GemmOperationUniversal):
            self.initialize()

    def get_arguments(self):
        """
        Build the kernel argument structure and store it, together with the
        tiled grid shape and ``gemm_k_size``, as a tuple in ``self.arguments``.

        Reads ``self.grid_tiled_shape`` and ``self.gemm_k_size``, which are not
        assigned in this class (presumably set by ``rt_module.plan`` -- confirm).
        """
        problem_size_ = self.problem_size.ctype
        grid_tiled_shape_ = GemmCoord(
            self.grid_tiled_shape.x,
            self.grid_tiled_shape.y,
            self.grid_tiled_shape.z).ctype

        if self.gemm_mode == GemmUniversalMode.Array:
            # Array mode passes the uploaded pointer tables and zero batch strides.
            arguments = self.operation.argument_type(
                # Arguments from UniversalArgumentsBase
                self.gemm_mode,
                problem_size_,
                self.batch_count,
                0,
                # Remaining arguments
                self.output_op,
                int(self.ptr_A_array_buffer.ptr),
                int(self.ptr_B_array_buffer.ptr),
                int(self.ptr_C_array_buffer.ptr),
                int(self.ptr_D_array_buffer.ptr),
                0, 0, 0,
                self.lda, self.ldb, self.ldc, self.ldd,
                self.lda, self.ldb, self.ldc, self.ldd,
                0, 0, 0
            )
        else:
            arguments = self.operation.argument_type(
                # Arguments from UniversalArgumentsBase
                self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
                # Remaining arguments
                self.output_op,
                int(self.ptr_A),
                int(self.ptr_B),
                int(self.ptr_C),
                int(self.ptr_D),
                self.batched_stride_A,
                self.batched_stride_B,
                self.batched_stride_C,
                self.lda, self.ldb, self.ldc, self.ldd,
                self.lda, self.ldb, self.ldc, self.ldd,
                0, 0, 0
            )

        self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size

    def initialize(self):
        """
        Plan the launch, allocate and zero the device workspace, and pack the
        kernel parameters into ``self.host_workspace``.

        Sets ``self.host_workspace``, ``self.device_workspace`` (always ``None``)
        and ``self.launch_config``.

        :raises RuntimeError: if zero-initializing the device workspace fails
        """
        launch_config = self.operation.rt_module.plan(self)

        # Get the host and device workspace
        device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)

        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            # The workspace is cleared in 32-bit words.
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to zero-initialize the device workspace: {err}")
        else:
            workspace_ptr = None

        device_workspace = 0
        if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
            # In GEMM split-K parallel, the D pointer is redirected to the workspace
            self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
        elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
            device_workspace = workspace_ptr

        self.get_arguments()

        # Only the argument structure is needed here; the other two tuple
        # members are consumed elsewhere.
        arguments, _grid_tiled_shape, _gemm_k_size = self.arguments
        res_arg = self.operation.rt_module.get_args(
            ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
        host_workspace = bytearray(res_arg.contents)

        self.host_workspace = host_workspace
        self.device_workspace = None
        self.launch_config = launch_config

    def sync(self, stream_sync=True):
        """
        Synchronize via the base class, then synchronize the output operator
        if it provides a ``sync`` method.

        :param stream_sync: forwarded to the base class ``sync``
        """
        super().sync(stream_sync)
        if hasattr(self.output_op, "sync"):
            self.output_op.sync()
366
+
367
+
368
class GemmArguments2xStreamK(GemmArguments2x):
    """
    Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and
    user-provided tensors into the kernel's arguments.

    Only the ``Gemm`` and ``Batched`` modes are accepted.

    :param operation: the GEMM operation to take the argument
    :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
     :class:`cutlass_cppgen.backend.GemmOperationGrouped`

    :param problem_size: GEMM problem size gemm(M, N, K)
    :type problem_size: :class:`cutlass_cppgen.shape.GemmCoord`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param gemm_mode: GEMM mode
    :type gemm_mode: :class:`cutlass_library.GemmUniversalMode`

    :param output_op: output operator, optional
    :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
    """

    def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
        if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
            raise Exception(f"Unsupported GEMM mode {gemm_mode}.")

        super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)

    def get_arguments(self):
        """
        Build and return the stream-K kernel argument structure.

        Unlike the parent class, the result is returned rather than stored in
        ``self.arguments``.

        :return: instance of ``self.operation.argument_type``
        """
        # NOTE(review): the batch strides are recomputed from the problem size
        # here instead of using self.batched_stride_*, so a ``batch_strides``
        # keyword passed to the constructor has no effect -- confirm intended.
        batch_stride_A = self.problem_size.m * self.problem_size.k
        batch_stride_B = self.problem_size.k * self.problem_size.n
        batch_stride_C = self.problem_size.m * self.problem_size.n
        batch_stride_D = self.problem_size.m * self.problem_size.n

        arguments = self.operation.argument_type(
            self.gemm_mode,
            GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k),
            self.batch_count,
            self.output_op,
            int(self.ptr_A),
            int(self.ptr_B),
            int(self.ptr_C),
            int(self.ptr_D),
            batch_stride_A,
            batch_stride_B,
            batch_stride_C,
            batch_stride_D,
            self.lda, self.ldb, self.ldc, self.ldd,  # strides
            self.lda, self.ldb, self.ldc, self.ldd,
            -1,  # avail_sms
        )
        return arguments

    def initialize(self):
        """
        Allocate and zero the device workspace, pack the kernel parameters into
        ``self.host_workspace`` and derive the launch configuration from the
        grid shape reported by the runtime module.

        Sets ``self.host_workspace``, ``self.device_workspace`` (always ``None``)
        and ``self.launch_config``.

        :raises RuntimeError: if zero-initializing the device workspace fails
        """
        rt_module = self.operation.rt_module
        # Query the SM count once; it is needed by every runtime-module call below.
        sm_count = device_sm_count()

        # Get the host and device workspace
        device_workspace_size = rt_module.get_device_workspace_size(
            self,
            sm_count,
            rt_module.occupancy
        )

        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            workspace_ptr = self.workspace_buffer.ptr
            # The workspace is cleared in 32-bit words.
            err, = cuda.cuMemsetD32(
                workspace_ptr, 0, device_workspace_size // 4)
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError(f"Failed to zero-initialize the device workspace: {err}")
        else:
            workspace_ptr = None

        # The constructor only admits the Gemm and Batched modes, so the
        # split-K-parallel redirection of ptr_D can never apply here; the
        # workspace is handed to the kernel only in Gemm mode.
        device_workspace = 0
        if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
            device_workspace = workspace_ptr

        arguments = self.get_arguments()

        res_arg = rt_module.get_args(
            ctypes.byref(arguments),
            ctypes.c_void_p(int(device_workspace)),
            sm_count,
            rt_module.occupancy
        )
        host_workspace = bytearray(res_arg.contents)

        grid = rt_module.get_grid_shape(
            ctypes.byref(arguments),
            sm_count,
            rt_module.occupancy
        )

        self.host_workspace = host_workspace
        self.device_workspace = None
        self.launch_config = LaunchConfiguration(
            [grid.m, grid.n, grid.k],
            [rt_module.threads, 1, 1],
            rt_module.shared_memory_capacity
        )
478
+
479
+
480
+ class GemmArguments3x(GemmArguments2x):
481
+ """
482
+ Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and
483
+ user-provided tensors into the kernel's arguments
484
+
485
+ :param operation: the GEMM operation to take the argument
486
+ :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
487
+ :class:`cutlass_cppgen.backend.GemmOperationGrouped`
488
+
489
+ :param problem_size: GEMM problem size gemm(M, N, K)
490
+ :type problem_size: :class:`cutlass_cppgen.shape.GemmCoord`
491
+
492
+ :param A: tensor A
493
+ :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
494
+
495
+ :param B: tensor B
496
+ :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
497
+
498
+ :param C: tensor C
499
+ :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
500
+
501
+ :param D: tensor D
502
+ :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
503
+
504
+ :param gemm_mode: GEMM mode
505
+ :type gemm_mode: GemmUniversalMode
506
+
507
+ :param output_op: output operator, optional
508
+ :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
509
+ """
510
+
511
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
    """
    Validate the GEMM mode, then defer to the CUTLASS 2 argument constructor.

    :raises Exception: if ``gemm_mode`` is neither ``Gemm`` nor ``Batched``
    """
    supported_modes = (GemmUniversalMode.Gemm, GemmUniversalMode.Batched)
    if gemm_mode not in supported_modes:
        raise Exception(f"Unsupported GEMM mode {gemm_mode}.")

    super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
516
+
517
def get_arguments(self):
    """
    Assemble the CUTLASS 3 kernel argument structure.

    The structure is stored in ``self.arguments`` and also returned. When the
    kernel does not use the default epilogue and the output operator offers
    ``to_evt_params``, ``self.output_op`` is replaced by its converted form.

    :return: instance of the argument type produced by ``get_gemm_arguments_3x``
    """
    op = self.operation
    tile_desc = op.tile_description

    mainloop_args = get_mainloop_arguments_3x(
        tile_desc.kernel_schedule,
        op.A.element,
        op.B.element,
        op.A.alignment,
        op.B.alignment,
    )
    scheduler_args = get_tile_scheduler_arguments_3x(tile_desc.tile_scheduler)
    uses_default_epilogue = op.rt_module.uses_default_epilogue()
    argument_type, epilogue_args, _epilogue_type, hw_info = get_gemm_arguments_3x(
        mainloop_args, op.epilogue_functor, scheduler_args, uses_default_epilogue)

    problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)

    # Batch strides only matter when there is more than one batch.
    if self.batch_count > 1:
        batch_strides = (
            self.batched_stride_A,
            self.batched_stride_B,
            self.batched_stride_C,
            self.batched_stride_D,
        )
    else:
        batch_strides = (0, 0, 0, 0)
    leading_dims = (self.lda, self.ldb, self.ldc, self.ldd)
    stride_A, stride_B, stride_C, stride_D = (
        StrideBatched_(ld, bs) for ld, bs in zip(leading_dims, batch_strides)
    )

    # Superset of potential mainloop arguments
    generic_args = GenericMainloopArguments3x_(
        int(self.ptr_A),
        stride_A,
        int(self.ptr_B),
        stride_B,
        4  # mma_promotion_interval
    )

    # Set of mainloop arguments needed for this kernel
    mainloop = mainloop_args.from_generic_mainloop_args(generic_args)

    needs_evt_params = (not uses_default_epilogue) and hasattr(self.output_op, "to_evt_params")
    if needs_evt_params:
        self.output_op = self.output_op.to_evt_params()

    epilogue = epilogue_args(
        self.output_op,
        int(self.ptr_C),
        stride_C,
        int(self.ptr_D),
        stride_D,
    )

    # Set hardware info
    hw_info_ = hw_info(
        0, device_sm_count(), 0,
        dim3_(0, 0, 0),
        dim3_(0, 0, 0),
    )

    self.arguments = argument_type(
        int(self.gemm_mode),
        problem_size_,
        mainloop,
        epilogue,
        hw_info_,
        scheduler_args
    )
    return self.arguments
586
+
587
def initialize(self):
    """
    Allocate the device workspace (if any), build the kernel parameters and
    compute the launch configuration.

    Sets ``self.host_workspace`` (the parameter bytes returned by the runtime
    module), ``self.device_workspace`` (always ``None``) and
    ``self.launch_config``. When a workspace is required it is kept alive in
    ``self.workspace_buffer``.

    :raises RuntimeError: if zero-filling the device workspace fails
    """
    rt_module = self.operation.rt_module

    # Get the host and device workspace
    device_workspace_size = rt_module.get_device_workspace_size(self)

    if device_workspace_size > 0:
        self.workspace_buffer = device_mem_alloc(device_workspace_size)
        workspace_ptr = self.workspace_buffer.ptr
        # cuMemsetD32 counts 32-bit words, hence the division by 4.
        err, = cuda.cuMemsetD32(
            workspace_ptr, 0, device_workspace_size // 4)
        if err != cuda.CUresult.CUDA_SUCCESS:
            # Previously the status was discarded, so a failed memset would
            # leave the kernel running on an uninitialized workspace.
            raise RuntimeError(
                f"CUDA error on call to cuMemsetD32: {cuda.cuGetErrorString(err)[1]}"
            )
    else:
        workspace_ptr = None

    device_workspace = 0
    if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
        # In GEMM split-K parallel, the D pointer is redirected to the workspace
        # NOTE(review): this class's __init__ only admits Gemm and Batched, so
        # this branch looks unreachable here - confirm before relying on it.
        self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
    elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
        device_workspace = workspace_ptr

    # Refresh self.arguments so it reflects any pointer redirection above.
    self.get_arguments()
    res_arg = rt_module.get_args(
        ctypes.byref(self.arguments),
        ctypes.c_void_p(int(device_workspace)),
    )
    host_workspace = bytearray(res_arg.contents)

    grid = rt_module.get_grid_shape(
        ctypes.byref(self.arguments),
        ctypes.c_void_p(int(device_workspace)),
    )
    block = rt_module.get_block_shape()

    device_workspace = None

    self.host_workspace = host_workspace
    self.device_workspace = device_workspace
    self.launch_config = LaunchConfiguration(
        [grid.x, grid.y, grid.z],
        [block.x, block.y, block.z],
        rt_module.shared_memory_capacity,
    )
628
+
629
+
630
def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
    """
    Argument wrapper for GEMM in CUTLASS 2 or 3. It returns 2x, 2x stream-K or
    3x arguments depending on the `swizzling_functor` and `api` fields of `operation`.

    :param operation: the GEMM operation to take the argument
    :type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
     :class:`cutlass_cppgen.backend.GemmOperationGrouped`

    :param problem_size: GEMM problem size gemm(M, N, K)
    :type problem_size: :class:`cutlass_cppgen.shape.GemmCoord`

    :param A: tensor A
    :type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param B: tensor B
    :type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param C: tensor C
    :type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param D: tensor D
    :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray

    :param gemm_mode: GEMM mode
    :type gemm_mode: :class:`cutlass_library.GemmUniversalMode`

    :param output_op: output operator, optional
    :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`

    :raises Exception: if a stream-K swizzling functor is combined with the 3.x API
    """
    uses_stream_k = operation.swizzling_functor == SwizzlingFunctor.StreamK
    targets_v3x = operation.api == ApiVersion.v3x

    if uses_stream_k and targets_v3x:
        raise Exception("Stream K is currently only supported in CUTLASS 2.x")

    if uses_stream_k:
        argument_cls = GemmArguments2xStreamK
    elif targets_v3x:
        argument_cls = GemmArguments3x
    else:
        argument_cls = GemmArguments2x

    return argument_cls(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
667
+
668
+
669
class GemmGroupedArguments:
    """
    Argument wrapper for GEMM Grouped. It encodes problem information and
    user-provided tensors into the kernel's argument

    :param operation: the GEMM Grouped operation to take the argument
    :type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped`

    :param problem_sizes: list of GEMM problem size gemm(M, N, K)
    :type problem_sizes: list[:class:`cutlass_cppgen.shape.GemmCoord`]

    :param A: list of tensor A
    :type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]

    :param B: list of tensor B
    :type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]

    :param C: list of tensor C
    :type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]

    :param D: list of tensor D
    :type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]

    :param output_op: output operator, optional
    :type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`

    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
    :type stream: :class:`cuda.cuda.CUstream`
    """

    def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
        # Get number of problems in the group
        self.problem_count = len(problem_sizes)

        # Check the input arguments: one tensor of each kind per problem
        assert len(A) == self.problem_count
        assert len(B) == self.problem_count
        assert len(C) == self.problem_count
        assert len(D) == self.problem_count

        problem_size_host = []
        self.ptr_A_host = []
        self.ptr_B_host = []
        self.ptr_C_host = []
        self.ptr_D_host = []

        lda_host = []
        ldb_host = []
        ldc_host = []
        ldd_host = []

        self.partitions = 1

        self.operation = operation

        # Get the threadblock
        threadblock_shape = operation.tile_description.threadblock_shape
        self.threadblock_shape = GemmCoord(
            threadblock_shape[0],
            threadblock_shape[1],
            threadblock_shape[2],
        )
        self.threadblock_swizzle = operation.swizzling_functor

        self.total_tiles = 0

        self.gemm_arguments = []

        self.stream = kwargs.get("stream", cuda.CUstream(0))

        # Process the input arguments: wrap every problem in a 2.x argument
        # object and harvest its pointers, leading dimensions and tile count.
        for idx, problem_size in enumerate(problem_sizes):
            M, N, K = problem_size.m, problem_size.n, problem_size.k
            temp_argument = GemmArguments2x(
                operation=operation,
                problem_size=GemmCoord(M, N, K),
                A=A[idx], B=B[idx], C=C[idx], D=D[idx])
            self.gemm_arguments.append(temp_argument)

            problem_size_host.append(
                [temp_argument.problem_size.m,
                 temp_argument.problem_size.n,
                 temp_argument.problem_size.k]
            )

            self.ptr_A_host.append(int(temp_argument.ptr_A))
            lda_host.append(temp_argument.lda)

            self.ptr_B_host.append(int(temp_argument.ptr_B))
            ldb_host.append(temp_argument.ldb)

            self.ptr_C_host.append(int(temp_argument.ptr_C))
            ldc_host.append(temp_argument.ldc)

            self.ptr_D_host.append(int(temp_argument.ptr_D))
            ldd_host.append(temp_argument.ldd)

            # Get number of tiles
            grid = self.operation.rt_module.get_grid_shape(
                self.operation.rt_module.get_tiled_shape(
                    temp_argument.problem_size.ctype,
                    self.threadblock_shape.ctype,
                    temp_argument.batch_count
                )
            )
            self.total_tiles += grid.x * grid.y * grid.z

        # Upload the per-problem tables to the device
        self.problem_size_buffer = todevice(problem_size_host, np.int32)
        self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64)
        self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64)
        self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64)
        self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64)

        self.lda_buffer = todevice(lda_host, np.int64)
        self.ldb_buffer = todevice(ldb_host, np.int64)
        self.ldc_buffer = todevice(ldc_host, np.int64)
        self.ldd_buffer = todevice(ldd_host, np.int64)

        if "output_op" in kwargs:
            self.output_op = kwargs["output_op"]
            self.alpha = self.output_op.alpha
            self.beta = self.output_op.beta
        else:
            self.alpha = 1.0
            self.beta = 0.0
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        # Get host problem size.
        # The array must outlive this statement: its raw data address is handed
        # to the kernel arguments, so a temporary array would be freed at once
        # and leave a dangling pointer. Keep it referenced on the instance.
        self.problem_size_host_array = np.array(problem_size_host, dtype=np.int32)
        self.host_problem_size_ptr = self.problem_size_host_array.__array_interface__["data"][0]

        self.arguments = self.get_arguments()

        self.initialize()

    def get_arguments(self):
        """
        Build the operation's argument structure from the device buffers,
        problem/tile counts, output operator and host problem-size pointer.
        """
        return self.operation.argument_type(
            self.problem_size_buffer.ptr,
            self.problem_count,
            self.total_tiles,
            self.output_op,
            self.ptr_A_buffer.ptr,
            self.ptr_B_buffer.ptr,
            self.ptr_C_buffer.ptr,
            self.ptr_D_buffer.ptr,
            self.lda_buffer.ptr,
            self.ldb_buffer.ptr,
            self.ldc_buffer.ptr,
            self.ldd_buffer.ptr,
            ctypes.c_void_p(int(self.host_problem_size_ptr)),
        )

    def initialize(self):
        """
        Plan the launch, allocate the device workspace (if any) and build the
        kernel parameter bytes.

        Sets ``self.host_workspace``, ``self.device_workspace`` (always ``None``)
        and ``self.launch_config``.

        :raises RuntimeError: if zero-filling the device workspace fails
        """
        rt_module = self.operation.rt_module

        # Get launch configuration
        launch_config = rt_module.plan(self)

        # Get the host and device workspace
        device_workspace_size = rt_module.get_device_workspace_size(self)

        if device_workspace_size > 0:
            self.workspace_buffer = device_mem_alloc(device_workspace_size)
            # cuMemsetD32 counts 32-bit words, hence the division by 4.
            err, = cuda.cuMemsetD32(
                self.workspace_buffer.ptr, 0, device_workspace_size // 4)
            if err != cuda.CUresult.CUDA_SUCCESS:
                # Previously the status was discarded silently.
                raise RuntimeError(
                    f"CUDA error on call to cuMemsetD32: {cuda.cuGetErrorString(err)[1]}"
                )
        # NOTE(review): the workspace allocated above is not the pointer passed
        # to get_args below; only the host-precompute result (or 0) is.

        if self.operation.precompute_mode == SchedulerMode.Host:
            device_workspace_ptr = rt_module.host_precompute(
                self, rt_module.get_workspace_size(self),)
        else:
            device_workspace_ptr = 0

        result = rt_module.get_args(
            ctypes.byref(self.arguments),
            self.total_tiles,
            ctypes.c_void_p(int(device_workspace_ptr)),
        )
        host_workspace = bytearray(result.contents)

        device_workspace = None

        self.host_workspace = host_workspace
        self.device_workspace = device_workspace
        self.launch_config = launch_config

    def sync(self):
        """
        Synchronize the device, then call ``sync(stream_sync=False)`` on every
        per-problem argument object.

        :raises RuntimeError: if the device synchronization reports an error
        """
        err, = cudart.cudaDeviceSynchronize()
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(err))
        for arg in self.gemm_arguments:
            arg.sync(stream_sync=False)
863
+
864
+
865
+ ################################################################################
866
+ # Base class for GEMM runtime module
867
+ ################################################################################
868
+
869
+
870
class GemmRTbase(ExecutableOperation):
    """
    Common base for the GEMM runtime modules: holds the kernel entry-point
    template and the threadblock configuration taken from the operation.
    """

    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

  ${operation_name}${operation_suffix}::invoke(params, *shared_storage);
}
"""

    def __init__(self, operation: "GemmOperation"):
        super().__init__(operation)

        self.operation = operation

        tile_description = operation.tile_description
        tb_shape = tile_description.threadblock_shape
        self.threadblock_shape = GemmCoord(tb_shape[0], tb_shape[1], tb_shape[2])
        self.threadblock_swizzle = operation.swizzling_functor

        # Threads per threadblock
        self.threads = tile_description.num_threads

    def emit(self):
        """Emit the source for ``self.operation`` through the configured emitter."""
        return self.emitter.emit(self.operation)

    def can_implement(self, configuration, arguments):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def get_host_workspace_size(self, arguments):
        """Not implemented in the base class."""
        raise NotImplementedError()

    def get_device_workspace_size(self, arguments):
        """The base runtime module needs no device workspace."""
        return 0

    def initialize(self):
        """
        Set the kernel's maximum dynamic shared memory size to
        ``self.shared_memory_capacity``.

        :raises RuntimeError: if ``cuFuncSetAttribute`` reports an error
        """
        smem_attribute = cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
        err, = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=smem_attribute,
            value=self.shared_memory_capacity)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(
                f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}"
            )
924
+
925
+
926
+ ################################################################################
927
+ # Runtime module for GEMM Universal
928
+ ################################################################################
929
+
930
+
931
class GemmRTUniversal(GemmRTbase):
    """
    Runtime module for GEMM Universal kernels: host-side helper template,
    launch planning and device workspace sizing.
    """

    HostTemplate = r"""
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
    ${operation_name}_base::Params* params;
    params = new ${operation_name}_base::Params(*argument,
                                                -1, // SM count. Only used for stream-K
                                                -1  // Occupancy. Only used for stream-K
                                                );

    // Semaphore holds the pointer to the workspace in the Params struct
    params->semaphore = workspace;

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}_base::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
        output[i] = bytes[i];

    return output;
  }

  cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
    cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
    return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
        problem_size, tile_size, split_k_slices);
  }

  dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
    return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
  }
}
"""

    def __init__(self, operation):
        super().__init__(operation)

        # Return types of the extra host entry points declared in HostTemplate.
        self.extra_funcs = {
            "get_tiled_shape": GemmCoord_,
            "get_grid_shape": dim3_,
        }
        self.emitter = EmitGemmUniversalInstance("_type", operation.direct_store)

        argument_type, epilogue_type = get_gemm_arguments(operation.epilogue_functor)
        self.argument_type = argument_type
        self.epilogue_type = epilogue_type
        self.argtype = [
            ctypes.POINTER(argument_type),
            ctypes.POINTER(GemmCoord_),
            ctypes.c_int,
            ctypes.c_void_p,
        ]

    def plan(self, arguments):
        """
        Compute the launch configuration for ``arguments``.

        Also records ``grid_tiled_shape`` and ``gemm_k_size`` on ``arguments``.
        """
        tiled = self.get_tiled_shape(
            arguments.problem_size.ctype,
            self.threadblock_shape.ctype,
            arguments.batch_count,
        )

        split_k_modes = (GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel)
        gemm_k_size = arguments.problem_size.k
        if arguments.gemm_mode in split_k_modes:
            # K alignment derived from the 128-bit width and the A/B element sizes.
            align_k = max(
                128 // DataTypeSize[self.operation.A.element],
                128 // DataTypeSize[self.operation.B.element],
                1,
            )

            # K extent per slice: ceil(k / batch_count), rounded up to align_k.
            k_per_slice = (
                arguments.problem_size.k + arguments.batch_count - 1
            ) // arguments.batch_count
            gemm_k_size = ((k_per_slice + align_k - 1) // align_k) * align_k

            if gemm_k_size:
                slices = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size
                tiled = GemmCoord(tiled.m, tiled.n, slices).ctype

        arguments.grid_tiled_shape = dim3_(tiled.m, tiled.n, tiled.k)
        launch_grid = self.get_grid_shape(tiled)
        arguments.gemm_k_size = gemm_k_size

        grid_dims = [launch_grid.x, launch_grid.y, launch_grid.z]
        block_dims = [self.threads, 1, 1]
        return LaunchConfiguration(grid_dims, block_dims, self.shared_memory_capacity)

    def get_device_workspace_size(self, arguments: GemmArguments):
        """
        Return the device workspace size in bytes for ``arguments``.

        Reads ``arguments.grid_tiled_shape`` (which ``plan`` assigns) in the
        modes that need a workspace; other cases return 0.
        """
        mode = arguments.gemm_mode
        tiles = arguments.grid_tiled_shape if mode in (
            GemmUniversalMode.GemmSplitKParallel, GemmUniversalMode.Gemm) else None

        if mode == GemmUniversalMode.GemmSplitKParallel:
            # DataTypeSize is in bits, hence the division by 8.
            return (DataTypeSize[arguments.operation.C.element]
                    * arguments.batched_stride_D * tiles.z // 8)

        if mode == GemmUniversalMode.Gemm and arguments.split_k_slices > 1:
            # 4 bytes per (x, y) tile.
            return 4 * tiles.x * tiles.y

        return 0
1031
+
1032
+
1033
class GemmRTUniversalStreamK(GemmRTUniversal):
    """
    Runtime module for CUTLASS 2.x stream-K kernels
    """

    HostTemplate = r"""
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  using GemmType = ${operation_name}_base;

  // Get the params as byte array
  char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace,
                                     int sm_count, int occupancy) {
    GemmType::Params* params;
    params = new GemmType::Params(*argument, sm_count, occupancy);

    params->init_workspace(workspace);

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(GemmType::Params)];
    for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
        output[i] = bytes[i];

    return output;
  }

  dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
    typename GemmType::Params params(*args, device_sms, sm_occupancy);
    return params.get_grid_dims();
  }

  uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
    typename GemmType::Params params(*args, device_sms, sm_occupancy);
    return params.get_workspace_size();
  }
}
"""

    def __init__(self, operation: "GemmOperation"):
        super().__init__(operation)

        # Return types of the extra host entry points declared in HostTemplate.
        self.extra_funcs = {
            "get_grid_shape": GemmCoord_,
            "get_kernel_workspace_size": ctypes.c_uint64,
        }
        # Filled in lazily by the `occupancy` property.
        self._occupancy = None

        argument_type, epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor)
        self.argument_type = argument_type
        self.epilogue_type = epilogue_type

    @property
    def occupancy(self):
        """
        Occupancy value reported by the CUDA driver for this kernel; queried on
        first access and cached in ``self._occupancy``.

        :raises RuntimeError: if the occupancy query reports an error
        """
        if self._occupancy is not None:
            return self._occupancy

        err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
            self.kernel, self.threads, self.shared_memory_capacity,
            cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE)

        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(
                "CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: "
                f"{cuda.cuGetErrorString(err)[1]}")
        return self._occupancy

    def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int):
        """
        Ask the ``get_kernel_workspace_size`` entry point for the workspace size
        of the structure built by ``arguments.get_arguments()``, given the SM
        count and occupancy.
        """
        kernel_arguments = arguments.get_arguments()
        return self.get_kernel_workspace_size(
            ctypes.byref(kernel_arguments), device_sms, sm_occupancy)
1104
+
1105
+
1106
+ ################################################################################
1107
+ # Runtime module for GEMM Universal within CUTLASS 3
1108
+ ################################################################################
1109
+
1110
+
1111
class GemmRTUniversal3x(GemmRTUniversal):
    """
    Runtime module for CUTLASS 3.x kernels: kernel entry point, host-side
    helper template and device workspace sizing.
    """

    KernelTemplate = r"""

using Operator = ${operation_name}${operation_suffix};
extern "C"
__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
void ${operation_name}(__grid_constant__ typename Operator::Params const params) {
  // Dynamic shared memory base pointer
  extern __shared__ char smem[];

  // Declare pointer to dynamic shared memory.
  Operator op;
  op(params, smem);
}
"""
    HostTemplate = r"""
extern "C" {
  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return ${operation_name}${operation_suffix}::SharedStorageSize;
  }

  using GemmType = ${operation_name}_base;

  bool ${operation_name}_uses_default_epilogue() {
    return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
  }

  // Get the workspace size
  uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
    return GemmType::get_workspace_size(*argument);
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){
    GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace);
    char *bytes = ((char*)(&params));
    char *output = new char[sizeof(GemmType::Params)];
    for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
        output[i] = bytes[i];

    return output;
  }

  // Get the total number of blocks for a persistent kernel
  uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) {
    auto problem_shape_MNKL = append<4>(problem, Int<1>{});
    auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
        cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl(
            problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{});
    return problem_blocks_m * problem_blocks_n * problem_blocks_l;
  }

  // Get the grid shape
  dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) {
    auto tmp_params = GemmType::to_underlying_arguments(*args, workspace);
    return GemmType::get_grid_shape(tmp_params);
  }

  // Get the block shape
  dim3 ${operation_name}_get_block_shape() {
    return GemmType::get_block_shape();
  }
}
"""

    def __init__(self, operation):
        super().__init__(operation)

        # Return types of the extra host entry points declared in HostTemplate.
        self.extra_funcs = {
            "get_grid_shape": dim3_,
            "get_block_shape": dim3_,
            "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
            "get_kernel_workspace_size": ctypes.c_uint64,
            "uses_default_epilogue": ctypes.c_bool,
        }
        # Replace the emitter installed by the parent constructor.
        self.emitter = EmitGemmUniversalInstance3x("_type")

    def get_device_workspace_size(self, arguments: GemmArguments3x):
        """
        Ask the ``get_kernel_workspace_size`` entry point for the workspace size
        of the structure built by ``arguments.get_arguments()``.
        """
        kernel_arguments = arguments.get_arguments()
        return self.get_kernel_workspace_size(ctypes.byref(kernel_arguments))
1199
+
1200
+
1201
class EmitGemmUniversalInstance3x:
    """Emits the C++ template definition of a CUTLASS 3 GEMM for an operation."""

    def __init__(self, operation_suffix=""):
        # Suffix appended to the operation name for the emitted named type.
        self.operation_suffix = operation_suffix

        # Headers needed by the emitted code.
        self.includes = [
            "cutlass/cutlass.h",
            "cute/tensor.hpp",
            "cute/atom/mma_atom.hpp",
            "cutlass/numeric_types.h",
            "cutlass/gemm/collective/collective_builder.hpp",
            "cutlass/gemm/kernel/sm90_tile_scheduler.hpp",
            "cutlass/gemm/kernel/gemm_universal.hpp",
            "cutlass/epilogue/collective/collective_builder.hpp",
            "cutlass/epilogue/collective/default_epilogue.hpp",
            "cutlass/epilogue/thread/linear_combination.h"
        ]

        # Kernel-level template for built-in epilogue functors.
        self.gemm_template_kernel = """
using namespace cute;

using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ${element_accumulator}, ${element_epilogue},
    ${element_c}, ${layout_c}, ${align_c},
    ${element_d}, ${layout_d}, ${align_d},
    ${epilogue_schedule}
  >::CollectiveOp;

using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    ${element_a}, ${layout_a}, ${align_a},
    ${element_b}, ${layout_b}, ${align_b},
    ${element_accumulator},
    cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    ${stage_count_type},
    ${kernel_schedule}
  >::CollectiveOp;

// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,
    CollectiveMainloop,
    CollectiveEpilogue,
    ${tile_scheduler}
>;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""

        # Kernel-level template for epilogue functors that carry a visitor.
        self.gemm_template_kernel_visitor = """
using namespace cute;

${callback_decl}

using CollectiveEpilogue =
  typename cutlass::epilogue::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    ${element_accumulator}, ${element_epilogue},
    ElementC, StrideC, ${align_c},
    ElementD, StrideD, ${align_d},
    ${epilogue_schedule},
    ${callback_name}
  >::CollectiveOp;

using CollectiveMainloop =
  typename cutlass::gemm::collective::CollectiveBuilder<
    ${arch}, ${opcode_class},
    ${element_a}, ${layout_a}, ${align_a},
    ${element_b}, ${layout_b}, ${align_b},
    ${element_accumulator},
    cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
    cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
    ${stage_count_type},
    ${kernel_schedule}
  >::CollectiveOp;

// Gemm operator ${operation_name}
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,
    CollectiveMainloop,
    CollectiveEpilogue,
    ${tile_scheduler}
>;

// Define named type
struct ${operation_name}${operation_suffix} :
  public ${operation_name}_base { };
"""

        # Device-level template: the kernel template plus an adapter alias.
        device_adapter = """

// Define device-level operator
using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>;
"""
        self.gemm_template_device = self.gemm_template_kernel + device_adapter

    def emit(self, operation):
        """
        Render the C++ definition for ``operation``.

        Epilogue functors exposing a ``visitor`` attribute are rendered with the
        visitor template; all others use the kernel or device template according
        to ``operation.emission_type``.
        """
        tile = operation.tile_description

        # No explicit stage count: let the builder carve out the epilogue storage.
        if tile.stages is None or tile.stages == 0:
            stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>"
        else:
            stage_count_type = "_" + str(tile.stages)

        if operation.emission_type == EmissionType.Kernel:
            gemm_template = self.gemm_template_kernel
        else:
            gemm_template = self.gemm_template_device

        # Fall back to the automatic/default schedules where none is given.
        kernel_schedule = KernelScheduleType.ScheduleAuto
        epilogue_schedule = EpilogueScheduleType.ScheduleAuto
        tile_scheduler = TileSchedulerType.Default
        if tile.kernel_schedule is not None:
            kernel_schedule = tile.kernel_schedule
        if tile.epilogue_schedule is not None:
            epilogue_schedule = tile.epilogue_schedule
        if tile.tile_scheduler is not None:
            tile_scheduler = tile.tile_scheduler

        tile_m, tile_n, tile_k = tile.blackwell_threadblock_shape
        functor = operation.epilogue_functor

        values = {
            "operation_name": operation.procedural_name(),
            "operation_suffix": self.operation_suffix,
            "element_a": DataTypeTag[operation.A.element],
            "layout_a": LayoutTag[operation.A.layout],
            "element_b": DataTypeTag[operation.B.element],
            "layout_b": LayoutTag[operation.B.layout],
            "element_c": DataTypeTag[operation.C.element],
            "layout_c": LayoutTag[operation.C.layout],
            "element_d": DataTypeTag[functor.element_output],
            # D takes its layout and alignment from the C operand description.
            "layout_d": LayoutTag[operation.C.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "element_epilogue": DataTypeTag[functor.element_epilogue],
            "opcode_class": OpcodeClassTag[tile.math_instruction.opcode_class],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(tile_m),
            "threadblock_shape_n": str(tile_n),
            "threadblock_shape_k": str(tile_k),
            "cluster_m": str(tile.cluster_shape[0]),
            "cluster_n": str(tile.cluster_shape[1]),
            "cluster_k": str(tile.cluster_shape[2]),
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
            "align_c": str(operation.C.alignment),
            "align_d": str(operation.C.alignment),
            "stage_count_type": stage_count_type,
            "kernel_schedule": KernelScheduleTag[kernel_schedule],
            "epilogue_schedule": EpilogueScheduleTag[epilogue_schedule],
            "tile_scheduler": TileSchedulerTag[tile_scheduler]
        }

        if not hasattr(functor, "visitor"):
            # Built-in epilogue functor
            values["epilogue_functor"] = functor.emit()
            return SubstituteTemplate(gemm_template, values)

        # User-defined epilogue: the functor supplies the callback declaration.
        callback_name, callback_decl = functor.emit(operation)
        values["callback_name"] = callback_name
        values["callback_decl"] = callback_decl
        return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
1370
+
1371
+
1372
+ ###################################################################################################
1373
+ # Runtime module for GEMM Grouped
1374
+ ###################################################################################################
1375
+
1376
+
1377
class GemmRTGrouped(GemmRTbase):
    """
    Runtime module for grouped GEMM kernels.

    Carries the CUDA kernel template and the host-side ``extern "C"`` template
    for a grouped GEMM, and provides the Python helpers that run the host
    scheduling precompute, build the launch configuration and size the
    workspace.
    """

    # Device entry point: casts dynamic shared memory to the operation's
    # SharedStorage and invokes the operation functor.
    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

  // Dynamic shared memory base pointer
  extern __shared__ int SharedStorageBase[];

  // Declare pointer to dynamic shared memory.
  ${operation_name}${operation_suffix}::SharedStorage *shared_storage =
      reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

  ${operation_name}${operation_suffix} op;

  op(params, *shared_storage);
}
"""

    # Host helpers exported with C linkage so they can be called from Python.
    HostTemplate = r"""
extern "C" {

  // precompute scheduling information
  char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) {
    char* host_workspace = new char[workspace_bytes];
    ${operation_name}_base::ProblemVisitor::host_precompute(
      args.host_problem_sizes,
      args.problem_count,
      args.threadblock_count,
      (void*)host_workspace
    );
    return host_workspace;
  }

  // Get the size of params in bytes
  int ${operation_name}_get_param_size(){
    return sizeof(${operation_name}${operation_suffix}::Params);
  }

  // Get the size of dynamic shared memory in bytes
  int ${operation_name}_shared_memory_size() {
    return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
  }

  // Get the params as byte array
  char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){
    ${operation_name}_base::Params* params;
    params = new ${operation_name}_base::Params(*argument, workspace, tile_count);

    char *bytes = ((char*)(params));
    char *output = new char[sizeof(${operation_name}_base::Params)];
    for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
        output[i] = bytes[i];

    return output;
  }

  cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
    cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
    return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
        problem_size, tile_size, split_k_slices);
  }

  dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
    return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
  }
}
"""

    def __init__(self, operation: "GemmOperation"):
        """
        :param operation: grouped GEMM operation this runtime module serves;
            its ``epilogue_functor`` selects the argument structures
        """
        super(GemmRTGrouped, self).__init__(operation)
        # Extra host functions to expose besides the base class's own.
        # NOTE(review): the values look like ctypes return types (None where
        # the caller sets restype itself) -- confirm against the base class.
        self.extra_funcs = {
            "precompute": None,
            "get_tiled_shape": GemmCoord_,
            "get_grid_shape": dim3_,
        }
        self.emitter = EmitGemmGroupedInstance("_type")
        self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
        # Mirrors the signature of ``${operation_name}_get_params`` above.
        self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]

    def host_precompute(self, arguments, workspace_bytes):
        """
        Run the host-side scheduling precompute and upload its result.

        :param arguments: grouped GEMM arguments; ``arguments.arguments`` is
            passed by reference and ``arguments.total_tiles`` as the tile count
        :param workspace_bytes: size in bytes of the buffer the host function
            allocates and fills

        :return: ``.ptr`` of the device copy made by ``todevice``
        """
        # ctypes reads the declared signature from ``argtypes``; the previous
        # ``argtype`` spelling was silently ignored, so the call ran without
        # any argument conversion being declared.
        # NOTE(review): assumes ``self.precompute`` is a ctypes foreign
        # function bound elsewhere from ``extra_funcs`` -- confirm.
        self.precompute.argtypes = [
            self.argtype[0], ctypes.c_int, ctypes.c_longlong]
        # The returned ``char*`` is viewed as a fixed-size byte array so it can
        # be copied out in one step.
        self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes)

        problem_info = self.precompute(
            ctypes.byref(arguments.arguments),
            arguments.total_tiles,
            workspace_bytes)
        problem_info_array = bytearray(problem_info.contents)

        # copy to device memory
        return todevice(problem_info_array).ptr

    def plan(self, arguments):
        """
        Build the launch configuration for the grouped kernel.

        :param arguments: grouped GEMM arguments providing ``total_tiles``

        :return: ``LaunchConfiguration`` built from ``[total_tiles, 1, 1]``,
            ``[self.threads, 1, 1]`` and ``self.shared_memory_capacity``
        """
        return LaunchConfiguration(
            [arguments.total_tiles, 1, 1],
            [self.threads, 1, 1],
            self.shared_memory_capacity,
        )

    def get_workspace_size(self, arguments):
        """
        Return the workspace size in bytes required by the scheduler.

        :param arguments: grouped GEMM arguments providing ``total_tiles``

        :return: 0 for device-side scheduling, ``8 * total_tiles`` for host
            precompute; ``None`` for any other precompute mode
        """
        if self.operation.precompute_mode == SchedulerMode.Device:
            return 0
        elif self.operation.precompute_mode == SchedulerMode.Host:
            total_tiles = arguments.total_tiles
            entries_per_block = 1
            # 8 bytes per entry. The earlier inline comment said "three
            # int32_t", which does not match 8 bytes (two int32_t).
            # NOTE(review): confirm the per-tile record size against the C++
            # ProblemVisitor used by host_precompute.
            return 8 * entries_per_block * total_tiles
1489
+
1490
+
1491
+ ################################################################################
1492
+ # Runtime module for GEMM and grouped GEMM
1493
+ ################################################################################
1494
+
1495
+
1496
class GemmOperationBase:
    """
    State and naming logic shared by CUTLASS GEMM operations.
    """

    def __init__(
        self, gemm_kind, arch, tile_description: TileDescription,
        A: TensorDescription, B: TensorDescription, C: TensorDescription,
        epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1,
        api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs):
        """
        :param gemm_kind: kind of GEMM (stored as ``self.gemm_kind``)
        :param arch: compute capability the operation targets
        :param tile_description: tiling and math-instruction description
        :param A: description of operand A
        :param B: description of operand B
        :param C: description of operand C
        :param epilogue_functor: epilogue to apply
        :param swizzling_functor: threadblock swizzling functor
        :param api: CUTLASS API version to emit for
        :param emission_type: whether kernel-level or device-level code is emitted
        :param kwargs: optional ``direct_store`` flag (defaults to ``False``)
        """
        self.operation_kind: OperationKind = OperationKind.Gemm
        self.arch: int = arch
        self.tile_description: TileDescription = tile_description
        self.gemm_kind: GemmKind = gemm_kind

        self.api = api
        if self.api == ApiVersion.v3x:
            self.prefix = "3x"
        else:
            self.prefix = ""
        self.emission_type = emission_type

        # For 2.x kernel emission with a column-major C, operands A and B are
        # swapped and all layouts transposed. This mimics the transpose
        # performed by device::GemmUniversal. get_operands works on deep
        # copies, so the caller's TensorDescriptions are never modified.
        self.switched = (self.api != ApiVersion.v3x and
                         self.emission_type == EmissionType.Kernel and
                         C.layout == LayoutType.ColumnMajor)

        self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched)

        self.epilogue_functor = epilogue_functor
        self.swizzling_functor = swizzling_functor

        self.direct_store = kwargs.get("direct_store", False)

    @staticmethod
    def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool):
        """
        Copy the operand descriptions, optionally swapping and transposing them.

        When ``swap`` is set, the returned A is a copy of ``B`` and the returned
        B a copy of ``A``, and the layouts of all three copies are transposed.

        :param A: description of operand A
        :type A: TensorDescription
        :param B: description of operand B
        :type B: TensorDescription
        :param C: description of operand C
        :type C: TensorDescription
        :param swap: whether to swap A/B and transpose the layouts
        :type swap: bool

        :return: descriptions of operands A, B, and C
        :rtype: tuple[TensorDescription]
        """
        first, second = (B, A) if swap else (A, B)
        A_out = copy.deepcopy(first)
        B_out = copy.deepcopy(second)
        C_out = copy.deepcopy(C)
        if swap:
            for operand in (A_out, B_out, C_out):
                operand.layout = transpose_layout(operand.layout)
        return A_out, B_out, C_out

    def run(self, arguments: GemmArguments) -> cuda.CUresult:
        """
        Launch the kernel through the runtime module.

        :param arguments: prepared arguments providing the host and device
            workspaces, the launch configuration and the stream

        :return: the CUDA result code (always success when returned)

        :raises Exception: if the operation was built for device-level emission
        :raises RuntimeError: if the launch reports a CUDA error
        """
        if self.emission_type == EmissionType.Device:
            raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"')

        status = self.rt_module.run(
            arguments.host_workspace,
            arguments.device_workspace,
            arguments.launch_config,
            arguments.stream
        )

        if status != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("CUDA Error %s" % str(status))

        return status

    def is_complex(self):
        """Return whether the math operation is one of the complex multiply-adds."""
        complex_operators = (
            MathOperation.multiply_add_complex,
            MathOperation.multiply_add_complex_gaussian,
            MathOperation.multiply_add_complex_fast_f32,
        )
        return self.tile_description.math_instruction.math_operation in complex_operators

    def is_planar_complex(self):
        """Return whether the GEMM kind is planar complex (plain or array)."""
        return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)

    def accumulator_type(self):
        """Return the accumulator type, mapped to its complex form for complex math."""
        accum = self.tile_description.math_instruction.element_accumulator
        return get_complex_from_real(accum) if self.is_complex() else accum

    def short_math_name(self):
        """Return the short accumulator-type name, prefixed with ``g`` for Gaussian complex math."""
        math_operation = self.tile_description.math_instruction.math_operation
        if math_operation == MathOperation.multiply_add_complex_gaussian:
            return "g%s" % ShortDataTypeNames[self.accumulator_type()]
        return ShortDataTypeNames[self.accumulator_type()]

    def core_name(self):
        """The basic operation kind is prefixed with a letter indicating the accumulation type."""
        math_inst = self.tile_description.math_instruction

        shape_tag = ""
        intermediate_tag = ""

        # Math operations that add a suffix to the instruction-shape tag.
        op_suffixes = {
            MathOperation.xor_popc: "xor",
        }

        # Shape and intermediate-type tags are only produced for tensor-op
        # classes; other opcode classes leave both empty.
        if math_inst.opcode_class in (OpcodeClass.TensorOp, OpcodeClass.WmmaTensorOp):
            instruction_shape = math_inst.instruction_shape
            if instruction_shape is None:
                shape_tag = "Default"
            elif self.api == ApiVersion.v3x and self.arch >= 90:
                shape_tag = "%dx%dx%d" % tuple(instruction_shape)
            else:
                shape_tag = "%d%d%d" % tuple(instruction_shape)
            shape_tag += op_suffixes.get(math_inst.math_operation, "")

            # Name the instruction's A type only when it differs from both the
            # operand's storage type and the accumulator type.
            if math_inst.element_a not in (self.A.element, math_inst.element_accumulator):
                intermediate_tag = DataTypeNames[math_inst.element_a]

        return "%s%s%s%s" % (self.short_math_name(), shape_tag, intermediate_tag, GemmKindNames[self.gemm_kind])

    def extended_name(self):
        """Append data types if they differ from compute type."""
        accum = self.tile_description.math_instruction.element_accumulator

        if self.is_complex() or self.A.element == accum:
            name_template = "${core_name}"
        elif self.C.element != accum:
            name_template = "${element_c}_${core_name}_${element_a}"
        else:
            name_template = "${core_name}_${element_a}"

        return SubstituteTemplate(name_template, {
            "element_a": DataTypeNames[self.A.element],
            "element_c": DataTypeNames[self.C.element],
            "core_name": self.core_name(),
        })

    def extended_name_3x(self):
        """Generates a string representing the MMA atom. Assumes accumulator type is C type."""
        name_template = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}"
        return name_template.format(
            element_a=DataTypeNames[self.A.element],
            element_b=DataTypeNames[self.B.element],
            element_acc=DataTypeNames[self.accumulator_type()],
            element_c=DataTypeNames[self.C.element],
            element_d=DataTypeNames[self.epilogue_functor.element_output],
            core_name=self.core_name())

    def layout_name(self):
        """Return the short layout tags of A and B; complex kinds include the transform."""
        if self.is_complex() or self.is_planar_complex():
            tag_a = ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)]
            tag_b = ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
            return "%s%s" % (tag_a, tag_b)
        return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])

    # Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
    def layout_name_3x(self):
        operands = (self.A, self.B, self.C)
        if self.is_complex() or self.is_planar_complex():
            tags = [ShortComplexLayoutNames[(op.layout, op.complex_transform)] for op in operands]
        else:
            tags = [ShortLayoutTypeNames[op.layout] for op in operands]
        return "{}{}{}".format(*tags)

    # Generates a short string representing underlying kernel schedule type
    def kernel_schedule_name_3x(self):
        schedule = self.tile_description.kernel_schedule
        if schedule is None:
            schedule = KernelScheduleType.ScheduleAuto
        return KernelScheduleSuffixes[schedule]

    # Generates a short string representing underlying epilogue schedule type
    def epilogue_schedule_name_3x(self):
        schedule = self.tile_description.epilogue_schedule
        if schedule is None:
            schedule = EpilogueScheduleType.ScheduleAuto
        return EpilogueScheduleSuffixes[schedule]

    def procedural_name(self):
        """The full procedural name indicates architecture, extended name, tile size, and layout."""
        tile = self.tile_description
        opcode_class_name = OpcodeClassNames[tile.math_instruction.opcode_class]

        if not (self.api == ApiVersion.v3x and self.arch >= 90):
            # 2.x naming scheme.
            threadblock = tile.procedural_name_2x()
            return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
                p=self.prefix,
                op=opcode_class_name,
                ex=self.extended_name(),
                tb=threadblock,
                l=self.layout_name(),
                a=str(self.A.alignment)
            )

        # 3.x naming scheme: also encodes cluster shape, stages and schedules.
        kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
        return kernel_name_template.format(
            p=self.prefix,
            ar=self.arch,
            op=opcode_class_name,
            ex=self.extended_name_3x(),
            tbm=tile.threadblock_shape[0],
            tbn=tile.threadblock_shape[1],
            tbk=tile.threadblock_shape[2],
            cm=tile.cluster_shape[0],
            cn=tile.cluster_shape[1],
            ck=tile.cluster_shape[2],
            l=tile.stages,
            s=self.layout_name_3x(),
            al=str(self.A.alignment),
            k=self.kernel_schedule_name_3x(),
            e=self.epilogue_schedule_name_3x()
        )

    def configuration_name(self):
        """Return the configuration name, which is the same as the procedural name."""
        return self.procedural_name()
1741
+
1742
+
1743
class GemmOperationUniversal(GemmOperationBase):
    """
    Universal GEMM operation; picks the runtime module from the API version
    and the swizzling functor.
    """

    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
                 epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
        """
        :param arch: compute capability the operation targets
        :param tile_description: tiling and math-instruction description
        :param A: description of operand A
        :param B: description of operand B
        :param C: description of operand C
        :param epilogue_functor: epilogue to apply
        :param swizzling_functor: threadblock swizzling functor
        :param kwargs: forwarded to :class:`GemmOperationBase`

        :raises Exception: if Stream K swizzling is requested for a 3.x kernel
        """
        api = api_version(arch, tile_description.math_instruction.opcode_class, A.element)
        super(GemmOperationUniversal, self).__init__(
            GemmKind.Universal, arch, tile_description,
            A, B, C, epilogue_functor, swizzling_functor,
            api=api, **kwargs)

        uses_stream_k = swizzling_functor == SwizzlingFunctor.StreamK
        if api == ApiVersion.v3x:
            if uses_stream_k:
                raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels")
            rt_module_cls = GemmRTUniversal3x
        elif uses_stream_k:
            rt_module_cls = GemmRTUniversalStreamK
        else:
            rt_module_cls = GemmRTUniversal
        self.rt_module = rt_module_cls(self)

        # Re-export the runtime module's ctypes structures on the operation.
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def device_op(self):
        """
        Returns a new GemmOperationUniversal object that is constructed with emission type
        ``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
        any swapping performed by the kernel-emitted operation is reversed.

        :return: operation ready for device-level code emission
        :rtype: GemmOperationUniversal
        """
        # Applying get_operands again with the same flag undoes an earlier swap.
        operands = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
        return GemmOperationUniversal(
            self.arch, self.tile_description, *operands,
            self.epilogue_functor, self.swizzling_functor,
            emission_type=EmissionType.Device, direct_store=self.direct_store)
1775
+
1776
+
1777
class GemmOperationGrouped(GemmOperationBase):
    """
    Grouped GEMM operation backed by :class:`GemmRTGrouped`.
    """

    def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
                 epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
        """
        :param arch: compute capability the operation targets
        :param tile_description: tiling and math-instruction description
        :param A: description of operand A
        :param B: description of operand B
        :param C: description of operand C
        :param epilogue_functor: epilogue to apply
        :param swizzling_functor: threadblock swizzling functor
        :param kwargs: must contain ``precompute_mode``; the rest is forwarded
            to :class:`GemmOperationBase`
        """
        super(GemmOperationGrouped, self).__init__(
            GemmKind.Grouped, arch, tile_description,
            A, B, C, epilogue_functor, swizzling_functor, **kwargs)
        assert "precompute_mode" in kwargs, "missing keyword arguement 'precompute_mode'."
        self.precompute_mode = kwargs["precompute_mode"]

        # The runtime module reads self.precompute_mode, so build it last.
        self.rt_module = GemmRTGrouped(self)
        self.argument_type = self.rt_module.argument_type
        self.epilogue_type = self.rt_module.epilogue_type

    def device_op(self):
        """
        Returns a new GemmOperationGrouped object that is constructed with emission type
        ``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
        any swapping performed by the kernel-emitted operation is reversed.

        :return: operation ready for device-level code emission
        :rtype: GemmOperationGrouped
        """
        # Applying get_operands again with the same flag undoes an earlier swap.
        operands = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
        return GemmOperationGrouped(
            self.arch, self.tile_description, *operands, self.epilogue_functor,
            self.swizzling_functor, emission_type=EmissionType.Device,
            direct_store=self.direct_store, precompute_mode=self.precompute_mode)
1802
+
1803
+
1804
+ ###################################################################################################
1805
+ #
1806
+ # Emits single instances of a CUTLASS device-wide operator
1807
+ #
1808
+ ###################################################################################################
1809
+
1810
+
1811
class EmitGemmUniversalInstance:
    """Responsible for emitting a CUTLASS template definition"""

    def __init__(
        self,
        operation_suffix="",
        direct_store=False
    ):
        """
        :param operation_suffix: suffix appended to the operation name when
            declaring the named kernel type
        :param direct_store: select the direct-store epilogue template for
            kernel emission and include its header
        """
        self.operation_suffix = operation_suffix
        self.direct_store = direct_store
        # Headers the emitted code needs; emit() may extend this list.
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/gemm_coord.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/device/gemm.h",
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
        ]
        if self.direct_store:
            self.includes.append(
                "cutlass/epilogue/threadblock/default_epilogue_direct_store.h"
            )
        # Kernel-level emission with a built-in epilogue functor.
        self.gemm_template_kernel = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
    typename cutlass::gemm::kernel::DefaultGemmUniversal<
        ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
        ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
        ${element_c}, ${layout_c},
        ${element_accumulator},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor},
        ${swizzling_functor},
        ${stages},
        ${math_operation}
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
    public ${operation_name}_base { };
"""

        # Device-level emission.
        self.gemm_template_device = """
// Gemm operator ${operation_name}
using DeviceKernel =
    typename cutlass::gemm::device::GemmUniversal<
        // Data type and layout of operand A
        ${element_a}, ${layout_a},
        // Data type and layout of operand B
        ${element_b}, ${layout_b},
        // Data type and layout of operand C
        ${element_c}, ${layout_c},
        // Data type of accumulator
        ${element_accumulator},
        // Class of operation
        ${opcode_class},
        // Compute capability of the target kernel
        ${arch},
        // Threadblock tile shape
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        // Warp tile shape
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        // Instruction shape
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        // Epilogue functor
        ${epilogue_functor},
        // Swizzling function
        ${swizzling_functor},
        // Number of pipeline stages
        ${stages},
        // Alignment of operands A and B
        ${align_a}, ${align_b},
        // Type of math operation
        ${math_operation},
        // Complex transform types of operands A and B
        ${transform_a}, ${transform_b}
    >;
"""
        # Kernel-level emission with the direct-store epilogue.
        self.gemm_template_direct_store = """
// Gemm operator ${operation_name}
using ${operation_name}_default =
    typename cutlass::gemm::kernel::DefaultGemmUniversal<
        ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
        ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
        ${element_c}, ${layout_c},
        ${element_accumulator},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor},
        ${swizzling_functor},
        ${stages},
        ${math_operation}
>::GemmKernel;

using ${operation_name}_base =
    cutlass::gemm::kernel::GemmUniversal<
        ${operation_name}_default::Mma,
        cutlass::epilogue::threadblock::DefaultEpilogueDirectStore<
            ${operation_name}_default::Epilogue
        >::Epilogue,
        ${operation_name}_default::ThreadblockSwizzle
    >;

// Define named type
struct ${operation_name}${operation_suffix} :
    public ${operation_name}_base { };
"""
        # Kernel-level emission with an epilogue visitor.
        self.gemm_template_kernel_visitor = """

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
    cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
    cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
    ${element_c},
    ${align_c},
    ${epilogue_stages} /* epilogue stages */
>;

${callback_decl}

// Gemm operator ${operation_name}
using ${operation_name}_base =
    typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
        ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
        ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
        ${element_c}, ${layout_c}, ${align_c},
        ${element_accumulator},
        ${element_epilogue},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${callback_name},
        ${swizzling_functor},
        ${stages},
        ${math_operation},
        ${epilogue_stages} /* epilogue stages */
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
    public ${operation_name}_base { };
"""

    def _require_includes(self, headers):
        """Append each header to ``self.includes`` unless it is already listed."""
        for header in headers:
            if header not in self.includes:
                self.includes.append(header)

    def instance_template(self):
        """Return the template used to register the operation in a manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
      cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
    >("${operation_name}"));
${compile_guard_end}
"""

    def emit(self, operation):
        """
        Render the C++ definition for ``operation``.

        The template is chosen from the operation's emission type and this
        emitter's ``direct_store`` flag. An epilogue functor that has a
        ``visitor`` attribute always uses the visitor template, and the
        headers it needs are added to ``self.includes``.

        :param operation: GEMM operation to emit

        :return: the substituted template text
        :rtype: str
        """
        tile = operation.tile_description
        math_inst = tile.math_instruction
        threadblock_shape = tile.threadblock_shape
        warp_count = tile.warp_count

        # Warp tile = threadblock tile divided by the warp count per dimension.
        warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

        if operation.emission_type != EmissionType.Kernel:
            gemm_template = self.gemm_template_device
        elif self.direct_store:
            gemm_template = self.gemm_template_direct_store
        else:
            gemm_template = self.gemm_template_kernel

        values = {
            "operation_name": operation.procedural_name(),
            "operation_suffix": self.operation_suffix,
            "element_a": DataTypeTag[operation.A.element],
            "layout_a": LayoutTag[operation.A.layout],
            "element_b": DataTypeTag[operation.B.element],
            "layout_b": LayoutTag[operation.B.layout],
            "element_c": DataTypeTag[operation.C.element],
            "layout_c": LayoutTag[operation.C.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[math_inst.opcode_class],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(threadblock_shape[0]),
            "threadblock_shape_n": str(threadblock_shape[1]),
            "threadblock_shape_k": str(threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(math_inst.instruction_shape[0]),
            "instruction_shape_n": str(math_inst.instruction_shape[1]),
            "instruction_shape_k": str(math_inst.instruction_shape[2]),
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(tile.stages),
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
            "transform_a": ComplexTransformTag[operation.A.complex_transform],
            "transform_b": ComplexTransformTag[operation.B.complex_transform],
            "math_operation": MathOperationTag[math_inst.math_operation],
        }

        epilogue = operation.epilogue_functor
        if not hasattr(epilogue, "visitor"):
            values["epilogue_functor"] = epilogue.emit()
            return SubstituteTemplate(gemm_template, values)

        # Register the visitor headers once; a plain ``+=`` here used to add
        # duplicate entries every time emit() ran on the same emitter.
        self._require_includes([
            "cutlass/epilogue/threadblock/fusion/visitors.hpp",
            "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
        ])
        callback_name, callback_decl = epilogue.emit(operation)
        values["callback_name"] = callback_name
        values["callback_decl"] = callback_decl
        values["align_c"] = str(operation.C.alignment)
        values["element_epilogue"] = DataTypeTag[epilogue.element_epilogue]
        # A functor that does not specify epilogue stages gets one stage.
        values["epilogue_stages"] = str(getattr(epilogue, "epilogue_stages", 1))
        return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
2040
+
2041
+
2042
class EmitGemmGroupedInstance:
    """Responsible for emitting a CUTLASS template definition"""

    def __init__(self, operation_suffix=""):
        """
        :param operation_suffix: suffix appended to the operation name when
            declaring the named kernel type
        """
        self.operation_suffix = operation_suffix
        # Headers the emitted code needs.
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/kernel/gemm_grouped.h",
            "cutlass/gemm/kernel/default_gemm_grouped.h",
        ]
        # Kernel-level emission.
        self.gemm_template_kernel = """
// Gemm operator ${operation_name}
using ${operation_name}_base =
    typename cutlass::gemm::kernel::DefaultGemmGrouped<
        ${element_a}, ${layout_a}, ${transform_a}, ${align_a},
        ${element_b}, ${layout_b}, ${transform_b}, ${align_b},
        ${element_c}, ${layout_c},
        ${element_accumulator},
        ${opcode_class},
        ${arch},
        cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
        cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
        cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
        ${epilogue_functor},
        ${swizzling_functor},
        ${stages},
        ${precompute_mode},
        ${math_operation}
>::GemmKernel;

// Define named type
struct ${operation_name}${operation_suffix} :
    public ${operation_name}_base { };
"""
        # Device-level emission: the kernel definition plus a device wrapper.
        self.gemm_template_device = (
            self.gemm_template_kernel
            + """
using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>;
"""
        )

    def instance_template(self):
        """Return the template used to register the operation in a manifest."""
        return """
${compile_guard_start}
  manifest.append(new ${gemm_kind}<
      cutlass::gemm::device::GemmGrouped<${operation_name}>
    >("${operation_name}"));
${compile_guard_end}
"""

    def emit(self, operation):
        """
        Render the C++ definition for a grouped GEMM ``operation``.

        :param operation: grouped GEMM operation to emit

        :return: the substituted kernel-level or device-level template text,
            chosen by the operation's emission type
        :rtype: str
        """
        tile = operation.tile_description
        math_inst = tile.math_instruction
        threadblock_shape = tile.threadblock_shape
        warp_count = tile.warp_count

        # Warp tile = threadblock tile divided by the warp count per dimension.
        warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]

        # Support built-in epilogue functors or user-defined functions
        epilogue_functor = operation.epilogue_functor.emit()

        values = {
            "operation_name": operation.procedural_name(),
            "operation_suffix": self.operation_suffix,
            "element_a": DataTypeTag[operation.A.element],
            "layout_a": LayoutTag[operation.A.layout],
            "element_b": DataTypeTag[operation.B.element],
            "layout_b": LayoutTag[operation.B.layout],
            "element_c": DataTypeTag[operation.C.element],
            "layout_c": LayoutTag[operation.C.layout],
            "element_accumulator": DataTypeTag[operation.accumulator_type()],
            "opcode_class": OpcodeClassTag[math_inst.opcode_class],
            "arch": "cutlass::arch::Sm%d" % operation.arch,
            "threadblock_shape_m": str(threadblock_shape[0]),
            "threadblock_shape_n": str(threadblock_shape[1]),
            "threadblock_shape_k": str(threadblock_shape[2]),
            "warp_shape_m": str(warp_shape[0]),
            "warp_shape_n": str(warp_shape[1]),
            "warp_shape_k": str(warp_shape[2]),
            "instruction_shape_m": str(math_inst.instruction_shape[0]),
            "instruction_shape_n": str(math_inst.instruction_shape[1]),
            "instruction_shape_k": str(math_inst.instruction_shape[2]),
            "epilogue_functor": epilogue_functor,
            "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
            "stages": str(tile.stages),
            "align_a": str(operation.A.alignment),
            "align_b": str(operation.B.alignment),
            "transform_a": ComplexTransformTag[operation.A.complex_transform],
            "transform_b": ComplexTransformTag[operation.B.complex_transform],
            "precompute_mode": SchedulerModeTag[operation.precompute_mode],
            "math_operation": MathOperationTag[math_inst.math_operation],
        }

        is_kernel = operation.emission_type == EmissionType.Kernel
        gemm_template = self.gemm_template_kernel if is_kernel else self.gemm_template_device

        return SubstituteTemplate(gemm_template, values)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Common data types and string names/tags for them
35
+ """
36
+
37
+ import enum
38
+
39
+ from cutlass_library import (
40
+ ComplexTransform,
41
+ DataType,
42
+ DataTypeSize,
43
+ EpilogueScheduleType,
44
+ KernelScheduleSuffixes,
45
+ KernelScheduleType,
46
+ MathOperation,
47
+ OpcodeClass,
48
+ TileSchedulerType
49
+ )
50
+
51
+
52
+ # The following block implements enum.auto() for Python 3.5 variants that don't include it such
53
+ # as the default 3.5.2 on Ubuntu 16.04.
54
+ #
55
+ # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
56
+
57
try:
    from enum import auto as enum_auto
except ImportError:
    # Fallback used only when `enum` does not provide `auto`: hand out
    # consecutive integers, starting at zero, from a module-level counter.
    __cutlass_library_auto_enum = 0

    def enum_auto() -> int:
        """
        Returns the current value of the module-level counter and advances the counter by one.

        :return: the next unused integer
        :rtype: int
        """
        global __cutlass_library_auto_enum
        value = __cutlass_library_auto_enum
        __cutlass_library_auto_enum = value + 1
        return value
67
+
68
+
69
class DataTypeSizeBytes:
    """
    Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
    data type key is less than a full byte or a non-integer number of bytes.

    Intended to be subscripted directly on the class, e.g. `DataTypeSizeBytes[datatype]`.
    """

    @staticmethod
    def __class_getitem__(datatype):
        """
        Returns the number of bytes in size the data type is. Raises an exception if the data type
        is either less than a full byte or a non-integer number of bytes in size.

        :param datatype: data type to query (a key of `DataTypeSize`)

        :return: number of bytes the data type occupies
        :rtype: int

        :raises Exception: if the size in bits is below 8 or is not a multiple of 8
        """
        bits = DataTypeSize[datatype]
        if bits < 8:
            raise Exception(
                f"Data type {datatype} is less than one byte in size."
            )
        elif bits % 8 != 0:
            # The data type is interpolated so the message names the offending type.
            raise Exception(
                f"Data type {datatype} is not an integer number of bytes."
            )
        return bits // 8
96
+
97
+
98
class SchedulerMode(enum.Enum):
    """
    Scheduling modes; each member maps to a `GroupScheduleMode` tag in `SchedulerModeTag`
    """
    Device = enum_auto()
    Host = enum_auto()


# Maps each scheduler mode to the C++ enumerator name string used for it
SchedulerModeTag = {
    SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
    SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
}


# Maps each scheduler mode to a short name string
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
110
+
111
+
112
class FunctionalOp(enum.Enum):
    """
    Functional operations; each member has a corresponding entry in `FunctionalOpTag`
    """
    AtomicAdd = enum_auto()
    AtomicMaximum = enum_auto()
    Divides = enum_auto()
    Maximum = enum_auto()
    Minimum = enum_auto()
    Minus = enum_auto()
    Multiplies = enum_auto()
    MultiplyAdd = enum_auto()
    Plus = enum_auto()
    Exp = enum_auto()


# Maps each functional operation to the `cutlass::` name string used for it
FunctionalOpTag = {
    FunctionalOp.AtomicAdd: "cutlass::atomic_add",
    FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
    FunctionalOp.Divides: "cutlass::divides",
    FunctionalOp.Maximum: "cutlass::maximum",
    FunctionalOp.Minimum: "cutlass::minimum",
    FunctionalOp.Minus: "cutlass::minus",
    FunctionalOp.Multiplies: "cutlass::multiplies",
    FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
    FunctionalOp.Plus: "cutlass::plus",
    FunctionalOp.Exp: "cutlass::fast_exp_op",
}
137
+
138
+
139
class ActivationOp(enum.Enum):
    """
    Activation functions; each member has a corresponding entry in `ActivationOpTag`
    """
    DGelu = enum_auto()
    Gelu = enum_auto()
    GeluTaylor = enum_auto()
    HardSwish = enum_auto()
    Identity = enum_auto()
    LeakyReLU = enum_auto()
    ReLU = enum_auto()
    Sigmoid = enum_auto()
    SiLU = enum_auto()
    Tanh = enum_auto()


# Maps each activation to the `cutlass::epilogue::thread::` name string used for it.
# NOTE: the capitalization of the tag does not always match the member name
# (e.g. `ReLU` -> `ReLu`, `SiLU` -> `SiLu`).
ActivationOpTag = {
    ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
    ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
    ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
    ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
    ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
    ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
    ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
    ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
    ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
    ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
}
164
+
165
+
166
def op_tag(op) -> str:
    """
    Looks up the tag string for `op` in whichever *Tag dictionary matches its type,
    so that callers can pass either an ActivationOp or a FunctionalOp.

    :param op: operation to emit a tag for
    :type op: ActivationOp | FunctionalOp

    :return: tag corresponding to op
    :rtype: str

    :raises Exception: if `op` is neither an ActivationOp nor a FunctionalOp
    """
    # Checked in order: activation ops first, then functional ops.
    tag_tables = (
        (ActivationOp, ActivationOpTag),
        (FunctionalOp, FunctionalOpTag),
    )
    for op_type, tags in tag_tables:
        if isinstance(op, op_type):
            return tags[op]
    raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
184
+
185
+
186
class FloatRoundStyle(enum.Enum):
    """
    Floating-point rounding styles; each member has a corresponding entry in `FloatRoundStyleTag`
    """
    ToNearest = enum_auto()
    ToNearestSatfinite = enum_auto()
    Indeterminate = enum_auto()
    TowardZero = enum_auto()
    TowardInfinity = enum_auto()
    TowardNegInfinity = enum_auto()
    HalfUlpTruncDntz = enum_auto()
    HalfUlpTruncate = enum_auto()


# Maps each rounding style to the `cutlass::FloatRoundStyle::` enumerator name string used for it
FloatRoundStyleTag = {
    FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
    FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
    FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
    FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
    FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
    FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
    FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
    FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
}
207
+
208
+
209
class MathInstruction:
    """
    Description of the lowest-level matrix-multiply-accumulate operation to be used in a kernel
    """

    def __init__(
        self,
        instruction_shape,
        element_a,
        element_b,
        element_accumulator,
        opcode_class=OpcodeClass.Simt,
        math_operation=MathOperation.multiply_add,
    ):
        """
        :param instruction_shape: size of the [M, N, K] dimensions of the instruction
        :type instruction_shape: list or tuple
        :param element_a: data type of operand A
        :param element_b: data type of operand B
        :param element_accumulator: data type used in accumulation
        :param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
        :type opcode_class: cutlass_library.library.OpcodeClass
        :param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
        :type math_operation: MathOperation
        """
        # All arguments are stored as given; no validation or copying is performed.
        self.instruction_shape = instruction_shape
        self.element_a = element_a
        self.element_b = element_b
        self.element_accumulator = element_accumulator
        self.opcode_class = opcode_class
        self.math_operation = math_operation
240
+
241
+
242
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
    """
    Derives the Blackwell threadblock shape for a tile description.

    When the first cluster dimension is positive, each of the three threadblock dimensions
    is integer-divided by the matching cluster dimension, and the first dimension is doubled
    if the kernel schedule is a "2sm" schedule. Otherwise the math instruction's shape is
    returned unchanged (the same object, not a copy).

    :param tile_description: object providing `threadblock_shape` and `math_instruction.instruction_shape`
    :param cluster_shape: cluster dimensions; the first three entries are used
    :param kernel_schedule: key into `KernelScheduleSuffixes`, or None

    :return: tuple of (threadblock shape, whether the schedule suffix contains "2sm")
    :rtype: tuple
    """
    full_shape = tile_description.threadblock_shape

    if kernel_schedule is None:
        is_2sm = False
    else:
        is_2sm = "2sm" in KernelScheduleSuffixes[kernel_schedule]

    if cluster_shape[0] > 0:
        divided_shape = [full_shape[axis] // cluster_shape[axis] for axis in range(3)]
        if is_2sm:
            divided_shape[0] *= 2
        return divided_shape, is_2sm

    return tile_description.math_instruction.instruction_shape, is_2sm
256
+
257
+
258
class TileDescription:
    """
    Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
    stage count, and math instruction specification
    """

    def __init__(
        self,
        threadblock_shape,
        stages,
        warp_count,
        math_instruction,
        cluster_shape=None,
        kernel_schedule: KernelScheduleType = None,
        epilogue_schedule: EpilogueScheduleType = None,
        tile_scheduler: TileSchedulerType = None
    ):
        """
        :param threadblock_shape: shape of a threadblock tile
        :type threadblock_shape: list or tuple
        :param stages: number of pipeline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
                       number of stages that can be supported for an operation on a given architecture will be computed at a later time
        :type stages: int or None
        :param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
        :type warp_count: list, tuple, or None
        :param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
        :type math_instruction: MathInstruction
        :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster.
                              Defaults to `[1, 1, 1]` when omitted or `None`
        :param kernel_schedule: type of kernel schedule to use (only available for SM90+)
        :type kernel_schedule: cutlass_library.KernelScheduleType
        :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
        :type epilogue_schedule: cutlass_library.EpilogueScheduleType
        :param tile_scheduler: type of tile scheduler to use (only available for SM90+)
        :type tile_scheduler: cutlass_library.TileSchedulerType

        :raises Exception: if exactly one of `kernel_schedule` and `epilogue_schedule` is None
        """
        if ((kernel_schedule is None and epilogue_schedule is not None) or
                (kernel_schedule is not None and epilogue_schedule is None)):
            raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")

        # Build a fresh list per instance so that in-place edits to one description's
        # cluster shape cannot leak into descriptions created later with the default.
        if cluster_shape is None:
            cluster_shape = [1, 1, 1]

        self.threadblock_shape = threadblock_shape
        self.cluster_shape = cluster_shape
        self.kernel_schedule = kernel_schedule
        self.epilogue_schedule = epilogue_schedule
        self.tile_scheduler = tile_scheduler
        self.stages = stages

        self.math_instruction = math_instruction
        self.instruction_shape = math_instruction.instruction_shape

        # Number of warps along x, y, z directions
        self.warp_count = warp_count

        # Reads `self.threadblock_shape` and `self.math_instruction`, so it must run after they are set
        self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)

    def clone_and_update(self, td: dict):
        """
        Returns a new TileDescription in which every attribute listed in `td` takes the value
        from `td` and every other attribute is taken from this instance.

        Recognized keys are `cluster_shape`, `threadblock_shape`, `warp_count`, `stages`,
        `instruction_shape`, `kernel_schedule`, `epilogue_schedule`, and `tile_scheduler`;
        other keys in `td` are ignored.

        :param td: attribute overrides
        :type td: dict

        :return: updated copy of this tile description
        :rtype: TileDescription
        """
        overridable = (
            "cluster_shape",
            "threadblock_shape",
            "warp_count",
            "stages",
            "instruction_shape",
            "kernel_schedule",
            "epilogue_schedule",
            "tile_scheduler",
        )
        attrs = {key: td[key] if key in td else getattr(self, key) for key in overridable}

        # The instruction shape is not a constructor argument of TileDescription; it is
        # carried by a new MathInstruction that otherwise mirrors the current one.
        instruction_shape = attrs.pop("instruction_shape")
        attrs["math_instruction"] = MathInstruction(
            instruction_shape,
            self.math_instruction.element_a,
            self.math_instruction.element_b,
            self.math_instruction.element_accumulator,
            self.math_instruction.opcode_class,
            self.math_instruction.math_operation
        )

        return TileDescription(**attrs)

    @property
    def num_threads(self):
        """
        Returns the number of threads in the threadblock

        :return: number of threads in the threadblock
        :rtype: int or None (if warp count is None)
        """
        if self.warp_count is None:
            return None

        # 32 threads per warp, multiplied by the warp count along each dimension
        threads = 32
        for cnt in self.warp_count:
            threads *= cnt
        return threads

    def procedural_name(self):
        """
        Returns a name identifying the tile description, built from the cluster shape,
        the threadblock shape, and the stage count (0 if stages is None)

        :return: name identifying the tile description
        :rtype: str
        """
        emit_stages = 0 if self.stages is None else self.stages
        name = "%dx%dx%d_%dx%d_%dx%d" % (
            self.cluster_shape[0],
            self.cluster_shape[1],
            self.cluster_shape[2],
            self.threadblock_shape[0],
            self.threadblock_shape[1],
            self.threadblock_shape[2],
            emit_stages
        )

        return name

    def procedural_name_2x(self):
        """
        Returns a name identifying the tile description, built from the threadblock shape
        and the stage count only

        :return: name identifying the tile description
        :rtype: str
        """
        return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)

    def __str__(self):
        """
        Returns a string containing each of the tile description's values

        :return: contents of tile description
        :rtype: str
        """
        if self.kernel_schedule is not None:
            kschedule = self.kernel_schedule
        else:
            kschedule = KernelScheduleType.ScheduleAuto

        if self.epilogue_schedule is not None:
            eschedule = self.epilogue_schedule
        else:
            eschedule = EpilogueScheduleType.ScheduleAuto

        if self.tile_scheduler is not None:
            tschedule = self.tile_scheduler.name
        else:
            tschedule = "None"
        # The epilogue line reports the epilogue schedule, not the kernel schedule
        return f"""
{{
  ClusterShape: {self.cluster_shape}
  ThreadblockShape: {self.threadblock_shape}
  WarpCount: {self.warp_count}
  Stages: {self.stages if self.stages is not None else 'Auto'}
  InstructionShape: {self.math_instruction.instruction_shape}
  Kernel schedule: {kschedule.name}
  Epilogue schedule: {eschedule.name}
  TileScheduler: {tschedule}
}}"""
419
+
420
+
421
class TensorDescription:
    """
    Description of a tensor operand: its element type, layout, alignment, and complex transform
    """

    def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
        """
        :param element: data type of the tensor's elements
        :param layout: layout of the tensor
        :param alignment: requested alignment; for non-void element types it is capped at
                          `128 // DataTypeSize[element]`
        :param complex_transform: complex transform to associate with the tensor
        """
        self.element = element
        self.layout = layout
        self.complex_transform = complex_transform

        # Void elements have no size to derive a cap from, so the request is kept as given.
        self.alignment = alignment
        if element != DataType.void:
            # NOTE(review): 128 is presumably a width in bits, matching the units of DataTypeSize -- confirm
            alignment_cap = 128 // DataTypeSize[element]
            self.alignment = min(alignment_cap, alignment)
430
+
431
+
432
def CalculateSmemUsagePerStage(operation):
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    :param operation: operation for which the per-stage usage should be computed. The value of
                      `operation.tile_description.stages` is not used in this calculation
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int

    :raises Exception: if the operation kind is not `OperationKind.Gemm`
    """
    # `OperationKind` is not among this module's top-level imports; import it here so the
    # comparison below does not fail with a NameError.
    from cutlass_library import OperationKind

    m, n, k = operation.tile_description.threadblock_shape

    if operation.operation_kind == OperationKind.Gemm:
        stage_barrier_bytes = 32
        # DataTypeSize values are divided by 8 to convert the A (m x k) and B (k x n)
        # tile sizes to bytes
        return (
            (DataTypeSize[operation.A.element] * m * k // 8)
            + (DataTypeSize[operation.B.element] * k * n // 8)
            + stage_barrier_bytes
        )
    else:
        raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))
455
+
456
+
457
def CalculateSmemUsage(operation):
    """
    Returns the amount of shared memory in bytes consumed by a kernel.

    :param operation: operation for which shared memory usage should be computed. The per-stage
                      usage is multiplied by `operation.tile_description.stages`
    :type operation: cutlass_cppgen.backend.Operation

    :return: number of bytes of shared memory consumed across all stages
    :rtype: int
    """
    # NOTE(review): `stages` may be None for some tile descriptions, in which case this
    # multiplication would fail -- confirm callers always set it first
    return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)
469
+
470
+
471
class ApiVersion(enum.Enum):
    """
    Differentiate between CUTLASS 2.x and 3.x API versions
    """

    # CUTLASS 2.x API
    v2x = enum_auto()
    # CUTLASS 3.x API
    v3x = enum_auto()
478
+
479
+
480
def api_version(arch, opclass, dtype):
    """
    Selects between the CUTLASS 2.x and 3.x APIs for code emission based on the
    architecture, opcode class, and data type.

    The 3.x API is chosen only when all of the following hold: `arch` is one of
    90, 100, 101, or 103; `opclass` is `OpcodeClass.TensorOp`; and `dtype` is not
    `DataType.f64`. Every other combination uses the 2.x API.

    :param arch: compute capability of device on which to run
    :type arch: int
    :param opclass: class of the operation being performed
    :type opclass: cutlass_library.OpcodeClass
    :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
    :type dtype: cutlass_library.DataType

    :return: API version to be used in code emission
    :rtype: ApiVersion
    """
    uses_3x = (
        arch in (90, 100, 101, 103)
        and opclass == OpcodeClass.TensorOp
        and dtype != DataType.f64
    )
    return ApiVersion.v3x if uses_3x else ApiVersion.v2x
501
+
502
+
503
class EmissionType(enum.Enum):
    """
    Tags for whether to emit a kernel- or device-level operation
    """

    # Emit a kernel-level operation
    Kernel = enum_auto()
    # Emit a device-level operation
    Device = enum_auto()
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import numpy as np
34
+
35
+ import cutlass_cppgen
36
+ from cutlass_cppgen.utils.datatypes import is_numpy_tensor
37
+ from cutlass_cppgen.utils.lazy_import import lazy_import
38
+
39
+ if cutlass_cppgen.use_rmm:
40
+ import rmm
41
+ else:
42
+ cudart = lazy_import("cuda.cudart")
43
+
44
+
45
class PoolMemoryManager:
    """
    Creates an RMM pool memory resource, wraps it in a tracking adaptor, and installs the
    adaptor as RMM's current device resource.
    """

    def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
        """
        :param init_pool_size: forwarded to the pool resource as `initial_pool_size`
        :type init_pool_size: int
        :param max_pool_size: forwarded to the pool resource as `maximum_pool_size`
        :type max_pool_size: int
        """
        upstream = rmm.mr.CudaMemoryResource()
        self.pool = rmm.mr.PoolMemoryResource(
            upstream,
            initial_pool_size=init_pool_size,
            maximum_pool_size=max_pool_size,
        )
        self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
        # Side effect: replaces RMM's current device resource
        rmm.mr.set_current_device_resource(self.mr)

    def pool_size(self):
        """
        Returns whatever the underlying pool resource reports as its pool size
        """
        return self.pool.pool_size()
57
+
58
+
59
class DevicePtrWrapper:
    """
    Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
    (at least in terms of the interface used by the CUTLASS Python interface)
    """
    def __init__(self, dev_ptr):
        """
        :param dev_ptr: device pointer to wrap; stored as given
        """
        self.dev_ptr = dev_ptr

    @property
    def ptr(self):
        """
        Returns the wrapped device pointer
        """
        # NOTE(review): nothing in this class frees the wrapped pointer -- confirm who owns the allocation
        return self.dev_ptr
70
+
71
+
72
def _todevice(host_data):
    """
    Helper for transferring host data to device memory

    :param host_data: NumPy array to copy to the device

    :return: an RMM DeviceBuffer when RMM is in use, otherwise a DevicePtrWrapper around
             newly allocated device memory holding the array's contents in C order

    :raises Exception: if the device allocation or the host-to-device copy fails
    """
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer.to_device(host_data.tobytes())

    # cudaMemcpy reads `nbytes` consecutive bytes starting at the array's data pointer,
    # which only matches the array's contents when the array is C-contiguous. Make it so
    # (a no-op for arrays that already are), matching what `tobytes()` yields on the RMM path.
    contiguous = np.ascontiguousarray(host_data)
    nbytes = contiguous.nbytes
    dev_ptr_wrapper = device_mem_alloc(nbytes)
    err, = cudart.cudaMemcpy(
        dev_ptr_wrapper.ptr,
        contiguous.__array_interface__['data'][0],
        nbytes,
        cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    )
    if err != cudart.cudaError_t.cudaSuccess:
        raise Exception(f"cudaMemcpy failed with error {err}")
    return dev_ptr_wrapper
90
+
91
+
92
def todevice(host_data, dtype=np.float32):
    """
    Pass the host_data to device memory

    :param host_data: a Python list or a NumPy array
    :param dtype: NumPy dtype used to convert a list; ignored for NumPy arrays

    :return: result of the device transfer, or None if `host_data` is neither a list
             nor a NumPy array
    """
    if isinstance(host_data, list):
        as_array = np.array(host_data, dtype=dtype)
        return _todevice(as_array)
    if is_numpy_tensor(host_data):
        return _todevice(host_data)
    # Unsupported input types fall through without raising
    return None
100
+
101
+
102
def device_mem_alloc(size):
    """
    Allocates `size` bytes of device memory

    :param size: number of bytes to allocate

    :return: an RMM DeviceBuffer when RMM is in use, otherwise a DevicePtrWrapper around the
             pointer returned by cudaMalloc

    :raises Exception: if cudaMalloc does not report success
    """
    if cutlass_cppgen.use_rmm:
        return rmm.DeviceBuffer(size=size)

    status, device_ptr = cudart.cudaMalloc(size)
    if status != cudart.cudaError_t.cudaSuccess:
        raise Exception(f"cudaMalloc failed with error {status}")
    return DevicePtrWrapper(device_ptr)
110
+
111
+
112
def align_size(size, alignment=256):
    """
    Rounds `size` up to the nearest multiple of `alignment`

    :param size: value to round up
    :param alignment: multiple to round up to

    :return: smallest multiple of `alignment` that is greater than or equal to `size`
             (for non-negative integer `size` and positive integer `alignment`)
    """
    # Adding (alignment - 1) before the floor division turns it into a ceiling division
    padded = size + alignment - 1
    num_units = padded // alignment
    return num_units * alignment
114
+
115
+
116
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
    """
    Creates a PoolMemoryManager when RMM is in use

    :param init_pool_size: initial pool size passed to the PoolMemoryManager
    :param max_pool_size: maximum pool size passed to the PoolMemoryManager

    :return: the new PoolMemoryManager, or None when RMM is not in use
    """
    if not cutlass_cppgen.use_rmm:
        return None
    return PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import ctypes
34
+ from cutlass_cppgen.utils.lazy_import import lazy_import
35
+ cuda = lazy_import("cuda.cuda")
36
+
37
+ from cutlass_cppgen.backend.utils.device import device_cc
38
+
39
# Cached result of `supports_cluster_launch`; None until the first call
_supports_cluster_launch = None


def supports_cluster_launch():
    """
    Returns whether kernels should be launched with thread block clusters.

    The answer is computed once and cached: it is true when `device_cc()` is one of
    90, 100, 101, or 103 and the version reported by the `cuda` package is at least 11.8.

    :return: whether cluster launch is supported
    :rtype: bool
    """
    global _supports_cluster_launch
    if _supports_cluster_launch is None:
        # Import and parse the version only when the cache is empty, rather than on
        # every call (this function is consulted on each kernel launch).
        from cuda import __version__
        # Drop any "rc..." or ".post..." suffix before splitting into numeric components
        version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
        major, minor = version_splits[0], version_splits[1]
        _supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
    return _supports_cluster_launch
50
+
51
+
52
class LaunchConfiguration:
    """
    Grid dimensions, block dimensions, and shared memory size for a kernel launch
    """

    def __init__(self, grid=None, block=None, smem=0):
        """
        :param grid: grid dimensions; defaults to `[1, 1, 1]` when omitted or None
        :param block: block dimensions; defaults to `[1, 1, 1]` when omitted or None
        :param smem: shared memory size, stored as `shared_memory_capacity`
        """
        # Fresh lists are created per instance so that mutating one configuration's
        # default grid/block in place cannot affect configurations created later.
        self.grid = [1, 1, 1] if grid is None else grid
        self.block = [1, 1, 1] if block is None else block
        self.shared_memory_capacity = smem
57
+
58
+
59
class ExecutableOperation:
    """
    Base class pairing an operation description with a loaded module and kernel handle,
    and providing the kernel-launch entry points. Subclasses are expected to implement
    the methods that raise NotImplementedError here.
    """

    def __init__(self, operation):
        """
        :param operation: operation description; must provide `procedural_name()`
        """
        self.operation = operation
        # Both start unset; this class never assigns them itself
        self.module = None
        self.kernel = None

    def name(self):
        """
        Returns the procedural name of the wrapped operation
        """
        return self.operation.procedural_name()

    def emit(self):
        """
        Returns the source emitted for this operation; the base implementation emits nothing
        """
        return ""

    def can_implement(self, configuration, arguments):
        raise NotImplementedError()

    def get_host_workspace_size(self, arguments):
        raise NotImplementedError()

    def get_device_workspace_size(self, arguments):
        raise NotImplementedError()

    def plan(self, arguments):
        raise NotImplementedError()

    def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
        raise NotImplementedError()

    def run_with_clusters(self, launch_config, kernel_params, stream=None):
        """
        Launches `self.kernel` via `cuLaunchKernelEx`, attaching a cluster-dimension launch
        attribute when the operation's tile description has a `cluster_shape`.

        :param launch_config: provides `grid`, `block`, and `shared_memory_capacity`
        :type launch_config: LaunchConfiguration
        :param kernel_params: kernel parameters passed through to the launch call
        :param stream: stream to launch on; `cuda.CUstream(0)` is used when falsy

        :return: the CUDA driver result code of the first failing call, or of the launch
        """
        if not stream:
            stream = cuda.CUstream(0)
        if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
            attr = cuda.CUlaunchAttribute()
            attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
            attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
            attrs = [attr]

            # Allow for non-portable cluster sizes
            err, = cuda.cuFuncSetAttribute(
                self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
            if err != cuda.CUresult.CUDA_SUCCESS:
                return err
        else:
            attrs = []

        config = cuda.CUlaunchConfig()
        config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
        # The unpacking below sets all three block dimensions, including Z
        config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
        config.sharedMemBytes = launch_config.shared_memory_capacity
        config.hStream = stream
        config.attrs = attrs
        config.numAttrs = len(attrs)

        err, = cuda.cuLaunchKernelEx(
            config, f=self.kernel, kernelParams=kernel_params, extra=0)
        return err

    def run_without_clusters(self, launch_config, kernel_params, stream=None):
        """
        Launches `self.kernel` via `cuLaunchKernel` with the given grid, block, and shared memory size.

        :param launch_config: provides `grid`, `block`, and `shared_memory_capacity`
        :type launch_config: LaunchConfiguration
        :param kernel_params: kernel parameters passed through to the launch call
        :param stream: stream to launch on; `cuda.CUstream(0)` is used when falsy

        :return: the CUDA driver result code of the launch
        """
        if not stream:
            stream = cuda.CUstream(0)
        err, = cuda.cuLaunchKernel(
            self.kernel,
            launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
            launch_config.block[0], launch_config.block[1], launch_config.block[2],
            launch_config.shared_memory_capacity,
            stream,
            kernel_params,
            0)

        return err

    def run(self, host_workspace, device_workspace, launch_config, stream=None):
        """
        Launches the kernel, passing the address of `host_workspace` as its single parameter.

        :param host_workspace: writable buffer holding the kernel's parameter data
        :param device_workspace: not used by this method
        :param launch_config: provides `grid`, `block`, and `shared_memory_capacity`
        :type launch_config: LaunchConfiguration
        :param stream: stream to launch on; `cuda.CUstream(0)` is used when falsy

        :return: the CUDA driver result code of the launch
        """
        if not stream:
            stream = cuda.CUstream(0)
        # View the workspace as a C char array (no copy) and pack its address into a
        # one-element void* array, the form the launch calls receive as kernel parameters
        cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
        packed = (ctypes.c_void_p * 1)()
        packed[0] = ctypes.addressof(cArg)

        if supports_cluster_launch():
            return self.run_with_clusters(launch_config, packed, stream)
        else:
            return self.run_without_clusters(launch_config, packed, stream)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+ from __future__ import annotations
33
+
34
+ import ctypes
35
+ from typing import Union
36
+
37
+ from cutlass_cppgen.utils.lazy_import import lazy_import
38
+ cuda = lazy_import("cuda.cuda")
39
+ cudart = lazy_import("cuda.cudart")
40
+ import numpy as np
41
+
42
+ from cutlass_library import (
43
+ DataTypeNames,
44
+ DataTypeSize,
45
+ DataTypeTag,
46
+ LayoutType,
47
+ SubstituteTemplate
48
+ )
49
+
50
+ import cutlass_cppgen
51
+ from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
52
+ from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
53
+ from cutlass_cppgen.backend.library import TensorDescription
54
+ from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
55
+ from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
56
+ from cutlass_cppgen.shape import MatrixCoord
57
+ from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
58
+
59
+
60
# Placeholder binding for the name; a full class with the same name is defined
# later in this module and replaces this one.
class ReductionOperation:
    ...
62
+
63
+
64
class ReductionArguments:
    """
    Arguments of reduction

    Gathers the pointers, problem extent and epilogue arguments of one reduction
    launch and packs them into ``host_workspace``, the byte buffer later handed to
    the kernel launcher.
    """

    def __init__(
        self,
        operation: ReductionOperation,
        problem_size: "list[int]",
        partitions: int,
        workspace: cuda.CUdeviceptr,
        destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
        **kwargs,
    ) -> None:
        """
        :param operation: reduction operation providing ``C``, ``epilogue_type``,
            ``argument_type`` and ``rt_module``
        :param problem_size: ``[rows, columns]`` of the reduced matrix
        :param partitions: number of split-k partitions
        :param workspace: device pointer to the partial results
        :param destination: output tensor (numpy array, torch tensor or device pointer)
        :param source: source tensor, same kind as ``destination``
        :param kwargs: optional ``bias`` (bool), ``stream`` and ``output_op``

        :raises TypeError: if ``destination`` is none of the supported kinds
        """
        # tensor_C can be interpreted as the bias with bias=True in keyword args;
        # by default, tensor_C is not bias
        self.bias = kwargs.get("bias", False)
        # Only build the default stream when the caller did not supply one
        self.stream = kwargs["stream"] if "stream" in kwargs else cuda.CUstream(0)

        self.operation = operation
        self.ptr_workspace = workspace

        # number of split-k partitions
        self.partitions = partitions

        if is_numpy_tensor(destination):
            # Keep the host array so sync() can copy the result back into it
            self.host_D = destination
            self.destination_buffer = NumpyFrontend.argument(destination, True)
            self.source_buffer = NumpyFrontend.argument(source, False)
            self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
            self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
        elif is_torch_tensor(destination):
            self.ptr_destination = TorchFrontend.argument(destination)
            self.ptr_source = TorchFrontend.argument(source)
        elif isinstance(destination, cuda.CUdeviceptr):
            self.ptr_destination = destination
            self.ptr_source = source
        else:
            raise TypeError("unknown Type")

        self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])

        # rows * columns * element size // 8 -- presumably the byte size of one
        # partition, assuming DataTypeSize is expressed in bits (TODO confirm)
        self.partition_stride = (
            problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
        )

        if "output_op" in kwargs:
            self.output_op = kwargs["output_op"]
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        self.get_arguments()

    @staticmethod
    def get_tensor_ref(
        extent: "tuple[int]",
        device_ptr: cuda.CUdeviceptr,
        layout: LayoutType,
    ):
        """
        Build a 2D tensor reference from a device pointer.

        :param extent: ``(rows, columns)``; the column count is passed as the second
            argument of ``TensorRef2D_``
        :param device_ptr: device pointer of the tensor
        :param layout: tensor layout; only ``LayoutType.RowMajor`` is accepted

        :raises ValueError: for any layout other than row-major
        """
        if layout == LayoutType.RowMajor:
            return TensorRef2D_(int(device_ptr), extent[1])
        else:
            raise ValueError(f"Unknown layout type {layout}")

    def get_arguments(self):
        """
        Build the C argument structure and store its packed bytes in
        ``self.host_workspace``.
        """
        full_extent = [self.problem_size.row, self.problem_size.column]

        ref_workspace = ReductionArguments.get_tensor_ref(
            extent=full_extent,
            device_ptr=self.ptr_workspace,
            layout=LayoutType.RowMajor,
        )

        # With bias=True the source is referenced with a zero extent, i.e. a zero
        # second argument to TensorRef2D_ (presumably a broadcast stride -- confirm).
        source_extent = [0, 0] if self.bias else full_extent
        ref_source = ReductionArguments.get_tensor_ref(
            extent=source_extent,
            device_ptr=self.ptr_source,
            layout=LayoutType.RowMajor,
        )

        ref_destination = ReductionArguments.get_tensor_ref(
            extent=full_extent,
            device_ptr=self.ptr_destination,
            layout=LayoutType.RowMajor,
        )

        self.c_arguments = self.operation.argument_type(
            self.problem_size,
            self.partitions,
            self.partition_stride,
            ref_workspace,
            ref_destination,
            ref_source,
            self.output_op,
        )

        params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
        self.host_workspace = bytearray(params_.contents)

    def sync(self):
        """
        Wait for the device, copy the result back to the host array (when the
        destination was a numpy array) and release the buffers via ``free()``.

        :raises RuntimeError: if synchronization or the device-to-host copy fails
        """
        (err,) = cudart.cudaDeviceSynchronize()
        # cudaDeviceSynchronize is a runtime-API call, so compare against the
        # runtime status enum (the same one free() uses), not the driver enum.
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"CUDA Error {str(err)}")

        if hasattr(self, "host_D"):
            (err,) = cuda.cuMemcpyDtoH(
                self.host_D,
                self.ptr_destination,
                self.host_D.size * self.host_D.itemsize,
            )
            if err != cuda.CUresult.CUDA_SUCCESS:
                raise RuntimeError("CUDA Error %s" % str(err))

        self.free()

    def free(self):
        """
        Frees allocated device-side memory

        Released buffers are removed from the instance so that a repeated call
        does not pass an already freed pointer to ``cudaFree`` again.

        :raises RuntimeError: if ``cudaFree`` reports an error
        """
        # Free any device memory allocated manually
        if cutlass_cppgen.use_rmm:
            return

        for attr in ("destination_buffer", "source_buffer"):
            if not hasattr(self, attr):
                continue
            buf = getattr(self, attr)
            if isinstance(buf, DevicePtrWrapper):
                err, = cudart.cudaFree(buf.ptr)
                if err != cudart.cudaError_t.cudaSuccess:
                    raise RuntimeError(f"cudaFree failed with error {err}")
            # Drop the attribute itself; deleting only a local name would leave the
            # stale wrapper reachable and make a second free() release it twice.
            delattr(self, attr)
212
+
213
+
214
class ReductionRT(ExecutableOperation):
    """
    ReductionRT manages the CUTLASS runtime components for reduction

    It owns the device/host source templates, the argument and epilogue C types of
    the operation, and derives the launch geometry for a given set of arguments.
    """

    # Device entry point: reinterprets dynamic shared memory as the operation's
    # SharedStorage and invokes the operation with the given params.
    KernelTemplate = r"""
extern "C"
__global__ void
${operation_name}(${operation_name}${operation_suffix}::Params params) {

// Dynamic shared memory base pointer
extern __shared__ int SharedStorageBase[];

// Declare pointer to dynamic shared memory.
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);

${operation_name}${operation_suffix} op;

op(params, *shared_storage);
}
"""
    # Host helpers: sizes of Params / SharedStorage and a byte copy of Params.
    HostTemplate = r"""
extern "C" {
// Get the size of params in bytes
int ${operation_name}_get_param_size(){
return sizeof(${operation_name}${operation_suffix}::Params);
}

// Get the size of dynamic shared memory in bytes
int ${operation_name}_shared_memory_size() {
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
}

// Get the params as byte array
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
char *bytes = ((char*)(params));
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
output[i] = bytes[i];

return output;
}
}
"""

    def __init__(self, operation: ReductionOperation):
        """
        :param operation: reduction operation; its ``count`` and ``epilogue_functor``
            are read here
        """
        super().__init__(operation)

        self.operation: ReductionOperation = operation
        self.emitter = EmitReductionInstance("_type")

        # Number of elements handled per access, taken from the operation's count
        self.elements_per_access = self.operation.count

        reduction_types = get_reduction_params(operation.epilogue_functor)
        self.argument_type, self.epilogue_type = reduction_types
        self.argtype = [ctypes.POINTER(self.argument_type)]

    def emit(self):
        """Return the source emitted by the instance emitter for this operation."""
        return self.emitter.emit(self.operation)

    def plan(self, arguments: ReductionArguments):
        """
        Derive the launch configuration for ``arguments``.

        The block is ``(tile columns // elements_per_access, tile rows, 1)`` and the
        grid is the problem extent divided by the tile shape, rounded up.

        :param arguments: reduction arguments providing ``problem_size``
        :return: a ``LaunchConfiguration`` built from grid, block and shared memory size
        """
        tile = self.operation.shape
        extent = arguments.problem_size

        threads = [tile.column // self.elements_per_access, tile.row, 1]

        # Round-up division of the problem extent by the tile extent
        blocks_along_rows = (extent.row + tile.row - 1) // tile.row
        blocks_along_columns = (extent.column + tile.column - 1) // tile.column
        blocks = [blocks_along_rows, blocks_along_columns, 1]

        return LaunchConfiguration(blocks, threads, self.shared_memory_capacity)

    def initialize(self):
        """
        Set the kernel's maximum dynamic shared memory size to
        ``self.shared_memory_capacity``.

        :raises RuntimeError: if ``cuFuncSetAttribute`` fails
        """
        smem_attribute = cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES
        (err,) = cuda.cuFuncSetAttribute(
            self.kernel,
            attrib=smem_attribute,
            value=self.shared_memory_capacity,
        )
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error: {err}")
303
+
304
+
305
class ReductionOperation:
    """
    CUTLASS reduction Operation

    Describes one reduction kernel (tile shape, element types, epilogue functor) and
    owns the runtime module used to launch it.
    """

    def __init__(
        self,
        shape: MatrixCoord,
        C: TensorDescription,
        element_accumulator,
        element_workspace=None,
        element_compute=None,
        epilogue_functor=None,
        count: int = 1,
        partitions_per_stage: int = 4,
    ) -> None:
        """
        :param shape: tile shape of the reduction (``row`` x ``column``)
        :param C: description of the output tensor; its ``element`` is the output type
        :param element_accumulator: accumulator data type
        :param element_workspace: workspace data type; defaults to the accumulator type
        :param element_compute: compute data type; defaults to the accumulator type
        :param epilogue_functor: epilogue functor of the reduction
        :param count: reduce op processing size
        :param partitions_per_stage: number of partitions to reduce per stage
        """
        self.shape = shape
        self.epilogue_functor = epilogue_functor
        self.element_accumulator = element_accumulator

        # Workspace and compute types fall back to the accumulator type
        self.element_workspace = (
            element_accumulator if element_workspace is None else element_workspace
        )
        self.element_compute = (
            element_accumulator if element_compute is None else element_compute
        )

        self.element_output = C.element
        self.C: TensorDescription = C

        # Reduce op processing size
        self.count: int = count

        # Number of partitions to reduce per stage
        self.partitions_per_stage: int = partitions_per_stage

        # The runtime module reads count and epilogue_functor, so build it last
        runtime = ReductionRT(self)
        self.rt_module: ReductionRT = runtime
        self.argument_type = runtime.argument_type
        self.epilogue_type = runtime.epilogue_type

    def extended_name(self):
        """Return the workspace, accumulator, compute and output type names joined by ``_``."""
        type_fields = (
            "element_workspace",
            "element_accumulator",
            "element_compute",
            "element_output",
        )
        substitutions = {
            field: DataTypeNames[getattr(self, field)] for field in type_fields
        }
        pattern = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
        return SubstituteTemplate(pattern, substitutions)

    def configuration_name(self):
        """The full procedural name: fixed prefix, extended name and tile size"""
        tile = "%dx%d" % (self.shape.row, self.shape.column)
        substitutions = {
            "extended_name": self.extended_name(),
            "threadblock": tile,
        }
        return SubstituteTemplate(
            "cutlass_reduce_split_k_${extended_name}_${threadblock}",
            substitutions,
        )

    def procedural_name(self):
        """The full procedural name; identical to ``configuration_name()``"""
        return self.configuration_name()

    def run(self, arguments: ReductionArguments) -> cuda.CUresult:
        """
        Configure and launch the cuda kernel with input arguments

        :param arguments: reduction arguments providing ``host_workspace`` and ``stream``
        :return: the launch status (always the success code, since failures raise)
        :raises RuntimeError: if the launch does not return ``CUDA_SUCCESS``
        """
        runtime = self.rt_module
        launch_config = runtime.plan(arguments)

        # No device workspace is used for the reduction launch
        device_workspace = None

        status = runtime.run(
            arguments.host_workspace,
            device_workspace,
            launch_config,
            arguments.stream,
        )
        if status != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA Error {str(status)}")
        return status
403
+
404
+
405
class EmitReductionInstance:
    """
    Emits the C++ declaration of a ``ReduceSplitK`` kernel instance for a
    reduction operation.
    """

    def __init__(self, operation_suffix="") -> None:
        """
        :param operation_suffix: suffix appended to the operation name for the
            emitted struct
        """
        self.operation_suffix = operation_suffix
        # Headers associated with the emitted instance
        self.includes = [
            "cutlass/cutlass.h",
            "cutlass/numeric_types.h",
            "cutlass/arch/arch.h",
            "cutlass/arch/mma.h",
            "cutlass/layout/matrix.h",
            "cutlass/gemm/device/gemm.h",
            "cutlass/gemm/device/gemm_universal_adapter.h",
            "cutlass/gemm/kernel/default_gemm_universal.h",
            "cutlass/reduction/kernel/reduce_split_k.h",
            "cutlass/reduction/thread/reduction_operators.h",
        ]
        self.template = """
// Reduction kernel instance
using ${operation_name}_base =
typename cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
${epilogue_functor},
cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_output},
${count}>,
${partition_per_stage}>;

struct ${operation_name}${operation_suffix}:
public ${operation_name}_base { };
"""

    def emit(self, operation: ReductionOperation):
        """
        Substitute the operation's properties into the instance template.

        :param operation: reduction operation to emit
        :return: the template with all placeholders substituted
        """
        output_bits = DataTypeSize[operation.C.element]
        # Vector width of the epilogue in bits, capped at 128
        vector_length_bits = min(operation.C.alignment * output_bits, 128)
        epilogue_vector_length = vector_length_bits // output_bits

        tile = operation.shape
        values = {
            "operation_name": operation.configuration_name(),
            "operation_suffix": self.operation_suffix,
            "shape_row": str(tile.row),
            "shape_column": str(tile.column),
            "epilogue_functor": operation.epilogue_functor.emit(),
            "element_output": DataTypeTag[operation.element_output],
            "epilogue_vector_length": str(epilogue_vector_length),
            "element_accumulator": DataTypeTag[operation.element_accumulator],
            "element_compute": DataTypeTag[operation.element_compute],
            "element_workspace": DataTypeTag[operation.element_workspace],
            "count": str(operation.count),
            "partition_per_stage": str(operation.partitions_per_stage),
        }
        return SubstituteTemplate(self.template, values)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
# String-valued type hint naming the two GEMM operation classes. The names are
# not imported in this module, so the hint is kept as plain text.
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"

# String-valued type hint for the tensor-like inputs: a CUDA device pointer or a
# numpy / torch / cupy array.
Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utility functions for interacting with the device
35
+ """
36
+ from __future__ import annotations
37
+
38
+ from cutlass_cppgen.utils.lazy_import import lazy_import
39
+ cuda = lazy_import("cuda.cuda")
40
+ cudart = lazy_import("cuda.cudart")
41
+
42
+ import cutlass_cppgen
43
+ from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
44
+
45
+
46
def check_cuda_errors(result: list):
    """
    Checks whether `result` contains a CUDA error and, if so, raises it as an exception.
    Otherwise, returns the result contained in the remaining fields of `result`.

    :param result: the results of the `cudart` method, consisting of an error code and any method results
    :type result: list

    :return: non-error-code results from the `result` parameter: ``None`` when there are
        none, the single value when there is exactly one, otherwise the remaining slice
    :raises RuntimeError: if the error code is non-zero
    """
    # `result` is of the format : (cudaError_t, result...)
    err = result[0]
    if err.value:
        name = cudart.cudaGetErrorName(err)
        # The binding returns its own (status, name) pair; keep only the name so
        # the message does not embed the whole tuple.
        if isinstance(name, tuple):
            name = name[-1]
        if isinstance(name, bytes):
            name = name.decode(errors="replace")
        raise RuntimeError("CUDA error: {}".format(name))

    payload = result[1:]
    if not payload:
        return None
    if len(payload) == 1:
        return payload[0]
    return payload
67
+
68
+
69
def device_cc(device: int = -1) -> int:
    """
    Returns the compute capability of the device with ID `device`.

    The value is the decimal concatenation of the major and minor versions.

    :param device: ID of the device to query; -1 selects ``cutlass_cppgen.device_id()``
    :type device: int

    :return: compute capability of the queried device (e.g., 80 for SM80)
    :rtype: int
    """
    device_id = cutlass_cppgen.device_id() if device == -1 else device

    properties = check_cuda_errors(cudart.cudaGetDeviceProperties(device_id))
    digits = (str(properties.major), str(properties.minor))
    return int("".join(digits))
86
+
87
+
88
def device_sm_count(device: int = -1):
    """
    Returns the multiprocessor count of the device with ID `device`.

    :param device: ID of the device to query; -1 selects ``cutlass_cppgen.device_id()``
    :type device: int

    :return: value of ``CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT`` for the device
    :raises RuntimeError: if ``cuDeviceGetAttribute`` fails
    """
    if device == -1:
        device = cutlass_cppgen.device_id()

    # Local is named differently from the function to avoid shadowing it
    err, sm_count = cuda.cuDeviceGetAttribute(
        cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
    )
    if err != cuda.CUresult.CUDA_SUCCESS:
        detail = cuda.cuGetErrorString(err)[1]
        # The error string may come back as bytes; decode it for a readable message
        if isinstance(detail, bytes):
            detail = detail.decode(errors="replace")
        raise RuntimeError(
            "Failed to retrieve SM count. "
            f"cuDeviceGetAttribute() failed with error: {detail}"
        )

    return sm_count
101
+
102
+
103
def to_device_ptr(tensor) -> cuda.CUdeviceptr:
    """
    Converts a tensor to a CUdeviceptr

    The checks run in this order: numpy array, torch tensor, cupy array, existing
    ``CUdeviceptr`` (returned unchanged), plain ``int`` address.

    :param tensor: tensor to convert
    :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int

    :return: device pointer
    :rtype: cuda.CUdeviceptr

    :raises NotImplementedError: if ``tensor`` is none of the supported kinds
    """
    if is_numpy_tensor(tensor):
        # Address of the array's data buffer as exposed by the array interface
        return cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
    if is_torch_tensor(tensor):
        return cuda.CUdeviceptr(tensor.data_ptr())
    if is_cupy_tensor(tensor):
        return cuda.CUdeviceptr(int(tensor.data.ptr))
    if isinstance(tensor, cuda.CUdeviceptr):
        return tensor
    if isinstance(tensor, int):
        return cuda.CUdeviceptr(tensor)
    raise NotImplementedError(tensor)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.emit.pytorch import pytorch
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Common utilities for emitting CUTLASS kernels
35
+ """
36
+
37
+ import cutlass_cppgen
38
+
39
+ # Strings used for printing information about the generation of emitted scripts
40
# Banner sentence stating that a file was generated by the CUTLASS Python
# interface; the interface version is interpolated from cutlass_cppgen.__version__.
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"


# The banner as a single C/C++ line comment ("// ...") followed by a newline.
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
"""


# The banner as a single Python line comment ("# ...") followed by a newline.
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
"""
49
+
50
# C++ fragment that brace-initializes `typename DeviceKernel::Arguments arguments`
# in GemmUniversalMode::kGemm for a problem of size {M, N, K}.
#
# The fragment refers to M, N, K, alpha, beta, A, B, C and D without declaring
# them, so it has to be pasted into a scope that already provides those names
# (presumably the `${args}` slot of _CUTLASS_KERNEL_RUN_GEMM_2x -- confirm at
# the substitution site). All four batch strides are zero and each leading
# dimension comes from the packed layout of its matrix; ldd reuses LayoutC.
_CUTLASS_KERNEL_ARGS_2x = """
  typename DeviceKernel::Arguments arguments {
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K}, // problem size
    1,
    {alpha, beta},
    A, B, C, D,
    0, 0, 0, 0, // batch strides
    DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
    DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
    DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
    DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
  };
"""
64
+
65
# Stream-K flavour of the 2.x argument fragment: the same initializer list as
# _CUTLASS_KERNEL_ARGS_2x with one extra trailing field, `-1`, labelled
# `avail_sms` in the C++ comment.
#
# Like the non-stream-K fragment, it relies on M, N, K, alpha, beta, A, B, C
# and D being declared by the surrounding C++ scope.
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
  typename DeviceKernel::Arguments arguments {
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K}, // problem size
    1,
    {alpha, beta},
    A, B, C, D,
    0, 0, 0, 0, // batch strides
    DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
    DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
    DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
    DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
    -1 // avail_sms
  };
"""
80
+
81
# C++ template for `${name}_kernel_run`, the host-side launcher of a CUTLASS
# 2.x GEMM.
#
# Placeholders:
#   ${name} - prefix of the generated function name.
#   ${args} - C++ text that must declare a variable named `arguments`
#             (the code after the placeholder reads it).
#
# The generated function sizes and allocates a device workspace, initializes
# the operator on the null CUDA stream, returns early with the failing status
# if initialization does not succeed, and otherwise launches the GEMM and
# returns its status. `DeviceKernel` is not declared here and must be provided
# by the text this template is combined with.
_CUTLASS_KERNEL_RUN_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

cutlass::Status ${name}_kernel_run(int M, int N, int K,
                        const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
                        ElementCompute alpha, ElementCompute beta) {
  ${args}
  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.initialize(arguments,
                                              workspace.get(),
                                              nullptr); // CUDA stream

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = gemm_op();
  return status;
}
"""
104
+
105
# C++ template for `${name}_kernel_run`, the host-side launcher of a CUTLASS
# 3.x GEMM.
#
# Placeholder:
#   ${name} - prefix of the generated function name.
#
# Unlike the 2.x launcher, the argument struct is built inline: the problem
# shape is {M, N, K, L}, strides are derived with make_cute_packed_stride from
# the kernel's StrideA/B/C/D types, and a caller-supplied KernelHardwareInfo
# is forwarded. The operator is executed with a single `run(...)` call on the
# null CUDA stream and that call's status is returned. `DeviceKernel` must be
# declared by the text this template is combined with.
_CUTLASS_KERNEL_RUN_GEMM_3x = """
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
using StrideD = typename DeviceKernel::GemmKernel::StrideD;

using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

cutlass::Status ${name}_kernel_run(
        int M, int N, int K, int L,
        const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
        ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {

  typename DeviceKernel::Arguments arguments{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K, L}, // problem size
    {
      A, // ptrA
      cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
      B, // ptrB
      cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
    },
    {
      {alpha, beta},
      C, // ptrC
      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
      D, // ptrD
      cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
    },
    hw_info
  };

  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.run(arguments,
                                       workspace.get(),
                                       nullptr); // CUDA stream

  return status;
}
"""
148
+
149
+
150
# C++ template for `${name}_kernel_run`, the host-side launcher of a CUTLASS
# 2.x grouped GEMM.
#
# Placeholder:
#   ${name} - prefix of the generated function name.
#
# The generated code defines a namespace-scope `threadblock_count` initialized
# from DeviceKernel::sufficient(), so that call runs when the compiled module
# is loaded rather than per launch. The launcher receives per-problem arrays
# (sizes, A/B/C/D pointers, leading dimensions), builds the argument struct,
# allocates a workspace, initializes the operator on the null CUDA stream,
# returns early on failure and otherwise launches and returns the status.
_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;

int threadblock_count = DeviceKernel::sufficient();

cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
                        DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
                        int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
                        ElementCompute alpha, ElementCompute beta) {

  typename DeviceKernel::Arguments arguments {
    problem_sizes,
    problem_count,
    threadblock_count,
    {alpha, beta},
    A, B, C, D,
    lda, ldb, ldc, ldd
  };

  size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  DeviceKernel gemm_op;
  cutlass::Status status = gemm_op.initialize(arguments,
                                              workspace.get(),
                                              nullptr); // CUDA stream

  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = gemm_op();
  return status;
}
"""
185
+
186
+
187
# C++ template for `${name}_kernel_run`, the host-side launcher of a CUTLASS
# 2.x implicit-GEMM 2-D convolution.
#
# Placeholders:
#   ${name}           - prefix of the generated function name.
#   ${conv_kind_name} - suffix appended to `cutlass::conv::Operator::k` to
#                       select the convolution operator.
#
# The generated code:
#   * derives the A/B/C tensor extents from the problem size and wraps the raw
#     pointers in TensorRefs that use a packed TensorNHWC layout (D reuses the
#     extent of C);
#   * maps the string split_k_mode to SplitKMode ("serial" or "parallel") and
#     throws std::runtime_error for any other value;
#   * obtains a workspace through device_memory_allocation(), which is not
#     defined in this fragment and must be supplied by the surrounding
#     translation unit;
#   * returns early with the failing status if can_implement() or initialize()
#     does not succeed, otherwise launches on `stream` and returns the status.
_CUTLASS_KERNEL_RUN_CONV2D_2x = """

using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
namespace {
using TensorRefA = typename UnderlyingKernel::TensorRefA;
using TensorRefB = typename UnderlyingKernel::TensorRefB;
using TensorRefC = typename UnderlyingKernel::TensorRefC;
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
}

template<typename TensorRef, typename Element>
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
  cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
  TensorRef tensor_ref(ptr, layout);
  return tensor_ref;
}

cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
                        UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
                        UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
                        ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
                        cudaStream_t stream, int device_id=0) {
  // create the tensor references
  cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );
  cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );
  cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
    cutlass::conv::Operator::k${conv_kind_name}, *problem_size
  );

  TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
  TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
  TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
  TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);

  cutlass::conv::SplitKMode mode;
  if (split_k_mode == "serial") {
    mode = cutlass::conv::SplitKMode::kSerial;
  } else if (split_k_mode == "parallel") {
    mode = cutlass::conv::SplitKMode::kParallel;
  } else {
    throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
  }

  typename DeviceKernel::Arguments arguments{
    *problem_size,
    tensor_ref_A,
    tensor_ref_B,
    tensor_ref_C,
    tensor_ref_D,
    {alpha, beta},
    mode
  };

  DeviceKernel implicit_gemm_op;

  size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);

  void* workspace_ptr = device_memory_allocation(workspace_size, device_id);

  cutlass::Status status = implicit_gemm_op.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  //
  // Launch initialized CUTLASS kernel
  //
  status = implicit_gemm_op(stream);

  return status;
}
"""
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utilities for generating source for building a PyTorch CUDA extension that uses a CUTLASS kernel.
35
+ If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.
36
+
37
+ Example usage with JIT compilation:
38
+
39
+ .. highlight:: python
40
+ .. code-block:: python
41
+
42
+ plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
43
+ op = plan.construct()
44
+ mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
45
+
46
+ # Generate inputs for the GEMM
47
+ A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
48
+
49
+ # Run the module
50
+ D = mod.run(A, B, C)
51
+
52
+
53
+ Example usage without JIT compilation:
54
+
55
+ .. highlight:: python
56
+ .. code-block:: python
57
+
58
+ plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
59
+ op = plan.construct()
60
+ cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
61
+
62
+ After this call, the directory ``output`` contains ``setup.py``,
63
+ ``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
64
+ within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.
65
+
66
+ The module can later be used in Python via:
67
+
68
+ .. highlight:: python
69
+ .. code-block:: python
70
+
71
+ import torch
72
+ import cutlass_gemm
73
+
74
+ # Generate inputs for the GEMM
75
+ A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
76
+
77
+ # Run the module
78
+ D = cutlass_gemm.run(A, B, C)
79
+ """
80
+
81
+ import logging
82
+ import os
83
+
84
+ from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
85
+
86
+ from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
87
+ from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
88
+ from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
89
+ from cutlass_cppgen.backend.library import ApiVersion
90
+ from cutlass_cppgen.emit import common
91
+ from cutlass_cppgen.utils.datatypes import is_torch_available
92
+
93
+ if is_torch_available():
94
+ import torch
95
+
96
+
97
# Skeleton of the generated CUDA source file: the C-style autogen banner,
# the CUDA / PyTorch / CUTLASS includes, a `device_memory_allocation` helper,
# and three slots filled in by the emitter.
#
# Placeholders:
#   ${includes}    - additional #include lines for the chosen operation.
#   ${declaration} - C++ text placed before the implementation (presumably the
#                    kernel type declaration -- confirm at the substitution site).
#   ${impl}        - C++ text of the kernel launcher and PyTorch entry point.
#
# NOTE(review): device_memory_allocation() returns the data pointer of a
# function-local at::Tensor, so the pointer outlives the tensor object that
# owns it. This appears to rely on the PyTorch allocator not handing the
# memory to another tensor while the kernel uses it -- confirm. The local
# `stream` in that helper is assigned but never read.
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"

// helper function allocating the memory
void* device_memory_allocation(size_t size, int device_id=0) {
    if (size > 0) {
        torch::Device device(torch::kCUDA, device_id);
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
        at::Tensor device_tensor = torch::empty({(long)size,}, options);
        return reinterpret_cast<void*>(device_tensor.data_ptr());
    } else {
        return nullptr;
    }
}

${includes}
${declaration}
${impl}
"""
122
+
123
# Template of the generated C++ binding file for a single GEMM.
#
# Placeholder:
#   ${name} - name of the C++ wrapper; `${name}_kernel` is the function it
#             forwards to, which is only forward-declared here.
#
# The pybind11 module exposes the wrapper as `run(A, B, C=None, alpha=1.0,
# beta=0.0)`.
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
  return ${name}_kernel(A, B, C, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""
140
+
141
# Template of the generated C++ binding file for a grouped GEMM.
#
# Placeholder:
#   ${name} - name of the C++ wrapper; `${name}_kernel` is the function it
#             forwards to, which is only forward-declared here.
#
# Same shape as the single-GEMM binding, except that A, B and the optional C
# are vectors of tensors and the result is a vector of tensors. The pybind11
# module exposes the wrapper as `run(A, B, C=None, alpha=1.0, beta=0.0)`.
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);

// C++ interface
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
  return ${name}_kernel(A, B, C, alpha, beta);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
        py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
}
"""
159
+
160
# Template of the generated C++ binding file for a forward-propagation
# 2-D convolution.
#
# Placeholder:
#   ${name} - name of the C++ wrapper; `${name}_kernel` is the function it
#             forwards to, which is only forward-declared here.
#
# The pybind11 module exposes the wrapper as `run(A, B, C=None, stride,
# padding, dilation, alpha, beta, split_k_mode, split_k_slices)`.
#
# NOTE(review): the C++ signatures default `padding` to {0, 0}, but the
# pybind11 registration defaults it to (1, 1), so a Python caller that omits
# `padding` gets (1, 1). Confirm which default is intended.
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(
    const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
    const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1) {
    return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
        py::overload_cast<
            const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
            std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
                &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
                py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
                py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
                py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""
192
+
193
# Template of the generated C++ binding file for a gradient 2-D convolution.
#
# Placeholder:
#   ${name} - name of the C++ wrapper; `${name}_kernel` is the function it
#             forwards to, which is only forward-declared here.
#
# Differs from the fprop binding only by the leading `result_size` argument,
# a 4-tuple of ints that has no default. The pybind11 module exposes the
# wrapper as `run(result_size, A, B, C=None, stride, padding, dilation, alpha,
# beta, split_k_mode, split_k_slices)`.
#
# NOTE(review): the C++ signatures default `padding` to {0, 0}, but the
# pybind11 registration defaults it to (1, 1), so a Python caller that omits
# `padding` gets (1, 1). Confirm which default is intended.
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <pybind11/stl.h>

// CUDA forward declarations
at::Tensor ${name}_kernel(
    std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1);

// C++ interface
at::Tensor ${name}(
    std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
    std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
    float alpha=1.f, float beta=0.f,
    std::string split_k_mode="serial", int split_k_slices=1) {
    return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("run",
        py::overload_cast<
            std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
            std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
                &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
                py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
                py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
                py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
}
"""
225
+
226
# CUTLASS headers to include in generated GEMM sources, keyed by the API
# version of the kernel being emitted.
# The v3x entry lists "cutlass/gemm/device/gemm_universal_adapter.h" twice.
_PYTORCH_GEMM_INCLUDES = {
    ApiVersion.v2x: """
#include "cutlass/gemm/device/gemm_universal.h"
""",
    ApiVersion.v3x: """
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
""",
}
239
+
240
# CUTLASS headers to include in generated grouped-GEMM sources.
_PYTORCH_GROUPED_GEMM_INCLUDES = """
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"
"""
244
+
245
# CUTLASS headers to include in generated 2-D convolution sources. The
# fprop, dgrad and wgrad kernel headers are all listed, regardless of which
# convolution kind is emitted.
_PYTORCH_CONV2D_INCLUDES = """
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
"""
251
+
252
# Maps a cutlass_library DataType to the spelling of the matching torch dtype
# constant in C++ source. The values are strings meant to be pasted into
# generated code (e.g. through the `${torch_type_C}` placeholder of the GEMM
# implementation templates), not torch dtype objects. Only the six data types
# listed here have an entry.
_CUTLASS_TYPE_TO_TORCH_TYPE = {
    DataType.f16: "torch::kF16",
    DataType.f32: "torch::kF32",
    DataType.f64: "torch::kF64",
    DataType.s8: "torch::kI8",
    DataType.s32: "torch::kI32",
    DataType.bf16: "torch::kBFloat16",
}
260
+
261
# CUDA-side implementation of `${name}_kernel` for CUTLASS 2.x GEMMs: the 2.x
# launcher template followed by the PyTorch-facing entry point.
#
# Placeholders (in addition to those of the launcher template):
#   ${name}         - prefix of the generated function names.
#   ${torch_type_C} - C++ spelling of the torch dtype used for the output D.
#
# The entry point reads M and K from A and N from B, allocates D as an
# {M, N} tensor via B.new_empty, passes a null C pointer when C is not
# supplied, calls `${name}_kernel_run`, and raises through TORCH_CHECK if the
# returned status is not kSuccess.
#
# NOTE(review): ptrC is taken from the temporary returned by C->contiguous()
# in a separate statement; if C is not already contiguous that temporary is
# gone before the launch -- confirm callers always pass contiguous C.
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
    common._CUTLASS_KERNEL_RUN_GEMM_2x
    + """
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

    cutlass::Status status = ${name}_kernel_run(M, N, K,
                                                reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                                reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                                ptrC,
                                                reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""
)
286
+
287
# CUDA-side implementation of `${name}_kernel` for CUTLASS 3.x GEMMs: the 3.x
# launcher template followed by the PyTorch-facing entry point.
#
# Placeholders (in addition to those of the launcher template):
#   ${name}         - prefix of the generated function names.
#   ${torch_type_C} - C++ spelling of the torch dtype used for the output D.
#
# The entry point reads M and K from A and N from B, fixes the batch extent L
# to 1, allocates D as an {M, N} tensor via B.new_empty, passes a null C
# pointer when C is not supplied, calls `${name}_kernel_run`, and raises
# through TORCH_CHECK if the returned status is not kSuccess.
#
# Hardware info is cached in namespace-scope variables. The guard flag is set
# after the first query so the SM count is looked up once per loaded module
# instead of on every call. The device id is hard-coded to 0.
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
    common._CUTLASS_KERNEL_RUN_GEMM_3x
    + """
bool hw_info_queried = false;
cutlass::KernelHardwareInfo hw_info;

at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
    int M = A.size(0);
    int N = B.size(1);
    int K = A.size(1);
    int L = 1;

    // Query hardware info if we haven't already
    if (!hw_info_queried) {
        hw_info.device_id = 0;
        hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
        hw_info_queried = true;
    }

    typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
                                            nullptr :
                                            reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
    at::Tensor D = B.new_empty({M, N}, ${torch_type_C});

    cutlass::Status status = ${name}_kernel_run(M, N, K, L,
                                                reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
                                                reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
                                                ptrC,
                                                reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
                                                ElementCompute(alpha), ElementCompute(beta),
                                                hw_info);

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""
)
323
+
324
+
325
# CUDA-side implementation of `${name}_kernel` for grouped GEMMs: the grouped
# launcher template followed by the PyTorch-facing entry point.
#
# Placeholders (in addition to those of the launcher template):
#   ${name}         - prefix of the generated function names.
#   ${torch_type_C} - C++ spelling of the torch dtype used for each output.
#
# The entry point packs every per-problem array (problem sizes, A/B/C/D
# pointers, lda/ldb/ldc) into one host buffer, copies it to a single device
# allocation, and passes offsets into that allocation to `${name}_kernel_run`.
# No separate ldd array is stored; the ldc array is passed for both ldc and
# ldd. C pointers are null unless C is supplied and beta is non-zero.
#
# The host staging buffer is owned by a std::vector, so it is released on
# every exit path, including when C->at(i), new_empty or copy_from_host throws.
# The buffer size and layout are unchanged.
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
    common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
    + """
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
    size_t num = A.size();

    // To avoid performing many small cudaMallocs and host-to-device copies,
    // we serialize the grouped GEMM arguments on the host, allocate one
    // large chunk of device memory, and perform a single cudaMemcpy to
    // copy the host data to the device. Allocation overheads could be
    // avoided by using a memory pool.

    // Calculate the total size of the data to be copied from host to device
    size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
                        sizeof(DeviceKernel::ElementA*) +
                        sizeof(DeviceKernel::ElementB*) +
                        sizeof(DeviceKernel::ElementC*) +
                        sizeof(DeviceKernel::ElementC*) +
                        sizeof(int64_t) +
                        sizeof(int64_t) +
                        sizeof(int64_t);
    total_size *= num;

    // num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
    // of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
    // To ensure that we don't end up having misaligned loads in the kernel,
    // we pad to the nearest multiple of 8.
    //
    // Note that, even on a 32-bit system (for which sizeof(X*) will not equal
    // sizeof(int64_t)), only padding between the list of GemmCoords and the
    // list of ptr_As is sufficient because the set of four equal-length lists of pointers
    // (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
    // start on a multiple of 8.
    int64_t padding = 8 - (total_size % 8);
    total_size += padding;

    // Vector-owned staging buffer: freed automatically on return or exception
    std::vector<uint8_t> host_buffer(total_size);
    uint8_t* host_data = host_buffer.data();
    cutlass::DeviceAllocation<uint8_t> device_data(total_size);

    uint8_t* start = host_data;
    cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);

    // Apply the padding after the list of GemmCoords
    start += num * sizeof(cutlass::gemm::GemmCoord) + padding;

    int64_t ptr_A_offset = start - host_data;
    DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
    start += num * sizeof(DeviceKernel::ElementA*);

    int64_t ptr_B_offset = start - host_data;
    DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
    start += num * sizeof(DeviceKernel::ElementB*);

    int64_t ptr_C_offset = start - host_data;
    DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
    start += num * sizeof(DeviceKernel::ElementC*);

    int64_t ptr_D_offset = start - host_data;
    DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
    start += num * sizeof(DeviceKernel::ElementC*);

    int64_t lda_offset = start - host_data;
    int64_t* lda_host = reinterpret_cast<int64_t*>(start);
    start += num * sizeof(int64_t);

    int64_t ldb_offset = start - host_data;
    int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
    start += num * sizeof(int64_t);

    int64_t ldc_offset = start - host_data;
    int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
    start += num * sizeof(int64_t);

    std::vector<at::Tensor> D(num);

    bool need_C = (C != at::nullopt) && (beta != 0.f);
    for (size_t i = 0; i < num; ++i) {
        int M = A[i].size(0);
        int N = B[i].size(1);
        int K = A[i].size(1);
        *(problem_sizes_host + i) = {M, N, K};
        *(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
        *(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());

        if (need_C) {
            *(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
        }
        else {
            *(ptr_C_host + i) = nullptr;
        }

        D[i] = B[i].new_empty({M, N}, ${torch_type_C});
        *(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());

        *(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
        *(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
        *(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
    }

    device_data.copy_from_host(host_data);

    cutlass::Status status = ${name}_kernel_run(
        num,
        reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
        reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
        reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
        reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
        reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
        reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
        reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
        reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
        reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
        ElementCompute(alpha), ElementCompute(beta));

    TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
    return D;
}
"""
)
446
+
447
+ _PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
448
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
449
+
450
+ cutlass::Status status = ${name}_kernel_run(
451
+ &problem_size,
452
+ reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
453
+ reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
454
+ ptrC,
455
+ reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
456
+ alpha, beta,
457
+ split_k_mode, stream, B.device().index());
458
+
459
+ TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
460
+ return D;
461
+ }
462
+ """
463
+
464
+ _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
465
+ common._CUTLASS_KERNEL_RUN_CONV2D_2x
466
+ + """
467
+ at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
468
+ std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
469
+ float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
470
+ int N, H, W, C_, K, R, S, P, Q;
471
+ N = A.size(0);
472
+ C_ = A.size(1);
473
+ H = A.size(2);
474
+ W = A.size(3);
475
+
476
+ K = B.size(0);
477
+ R = B.size(2);
478
+ S = B.size(3);
479
+
480
+ cutlass::conv::Conv2dProblemSize problem_size(
481
+ cutlass::Tensor4DCoord(N, H, W, C_),
482
+ cutlass::Tensor4DCoord(K, R, S, C_),
483
+ cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
484
+ cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
485
+ cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
486
+ cutlass::conv::Mode::kCrossCorrelation,
487
+ split_k_slices
488
+ );
489
+
490
+ P = problem_size.P;
491
+ Q = problem_size.Q;
492
+
493
+ typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
494
+ nullptr :
495
+ reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
496
+
497
+ torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
498
+ at::Tensor D = torch::zeros({N, K, P, Q}, options);
499
+ """ + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
500
+ )
501
+
502
+
503
+ _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
504
+ common._CUTLASS_KERNEL_RUN_CONV2D_2x
505
+ + """
506
+ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
507
+ std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
508
+ std::string split_k_mode="serial", int split_k_slices=1) {
509
+ int N, H, W, C_, K, R, S;
510
+ N = std::get<0>(input_size);
511
+ C_ = std::get<1>(input_size);
512
+ H = std::get<2>(input_size);
513
+ W = std::get<3>(input_size);
514
+
515
+ K = B.size(0);
516
+ R = B.size(2);
517
+ S = B.size(3);
518
+
519
+ cutlass::conv::Conv2dProblemSize problem_size(
520
+ cutlass::Tensor4DCoord(N, H, W, C_),
521
+ cutlass::Tensor4DCoord(K, R, S, C_),
522
+ cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
523
+ cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
524
+ cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
525
+ cutlass::conv::Mode::kCrossCorrelation,
526
+ split_k_slices
527
+ );
528
+
529
+ typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
530
+ nullptr :
531
+ reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
532
+
533
+ torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
534
+ at::Tensor D = torch::empty({N, C_, H, W}, options);
535
+ """ + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
536
+ )
537
+
538
+
539
+ _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
540
+ common._CUTLASS_KERNEL_RUN_CONV2D_2x
541
+ + """
542
+ at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
543
+ std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
544
+ std::string split_k_mode="serial", int split_k_slices=1) {
545
+ int N, H, W, C_, K, R, S;
546
+ K = std::get<0>(weight_size);
547
+ C_ = std::get<1>(weight_size);
548
+ R = std::get<2>(weight_size);
549
+ S = std::get<3>(weight_size);
550
+
551
+ N = B.size(0);
552
+ H = B.size(2);
553
+ W = B.size(3);
554
+
555
+ cutlass::conv::Conv2dProblemSize problem_size(
556
+ cutlass::Tensor4DCoord(N, H, W, C_),
557
+ cutlass::Tensor4DCoord(K, R, S, C_),
558
+ cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
559
+ cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
560
+ cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
561
+ cutlass::conv::Mode::kCrossCorrelation,
562
+ split_k_slices
563
+ );
564
+
565
+ typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
566
+ nullptr :
567
+ reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
568
+
569
+ torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
570
+ at::Tensor D = torch::empty({K, C_, R, S}, options);
571
+ """ + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
572
+ )
573
+
574
+
575
+ _PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
576
+ from setuptools import setup
577
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
578
+
579
+ setup(
580
+ name='${name}',
581
+ ext_modules=[
582
+ CUDAExtension('${name}', [
583
+ '${name}.cpp',
584
+ '${name}_kernel.cu',
585
+ ],
586
+ include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
587
+ extra_compile_args={
588
+ 'cxx': ['-std=c++17'],
589
+ 'nvcc': ['-std=c++17', ${extra_compile_args}],
590
+ },
591
+ libraries=['cuda']
592
+ ),
593
+ ],
594
+ cmdclass={
595
+ 'build_ext': BuildExtension
596
+ })
597
+
598
+ """
599
+
600
+
601
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
    """
    Renders ``_PYTORCH_SETUP_PY`` and writes the result to ``setup.py`` inside ``sourcedir``

    :param name: name of the module to generate
    :type name: str
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str
    :param extra_compile_args: text substituted for ``${extra_compile_args}`` in the template,
        i.e. additional entries of the ``nvcc`` argument list of the generated setup.py
    :type extra_compile_args: str
    """
    target_path = os.path.join(sourcedir, "setup.py")
    substitutions = {
        "name": name,
        "cutlass_path": CUTLASS_PATH,
        "extra_compile_args": extra_compile_args,
    }
    rendered = SubstituteTemplate(_PYTORCH_SETUP_PY, substitutions)
    with open(target_path, "w") as handle:
        handle.write(rendered)
618
+
619
+
620
class _ArchListSetter:
    """
    Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
    environment variable when building a PyTorch CUDA module.

    ``TORCH_CUDA_ARCH_LIST`` is a space-delimited list of compute capabilities for which a PyTorch
    CUDA module should be compiled.

    For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
    ``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
    compilation of the module.

    This utility wraps the building of a PyTorch CUDA module with a setting of this environment
    variable according to the current compute capability being targeted.

    Example usage:

    .. highlight:: python
    .. code-block:: python

        # Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
        with _ArchListSetter(80):
            # Perform JIT compilation and loading of the module
            mod = torch.utils.cpp_extension.load(...)

    :param cc: compute capability, encoded as ``10 * major + minor`` (e.g. 80, 90, 100)
    :type cc: int
    """

    _TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"

    def __init__(self, cc: int):
        # Split arithmetically rather than digit-by-digit: joining the digits of a
        # three-digit capability such as 100 with "." would yield "1.0.0" instead of "10.0".
        major, minor = divmod(int(cc), 10)
        self.cc_str = f"{major}.{minor}"
        # Populated by __enter__; None means the variable was not set beforehand.
        self.old_arch_list = None

    def __enter__(self):
        """
        Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc``
        """
        self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
        os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str

        return self

    def __exit__(self, exc_type, exc_val, traceback):
        """
        Restores the old value of TORCH_CUDA_ARCH_LIST
        """
        if self.old_arch_list is None:
            # pop() instead of del: the wrapped code may already have removed the variable,
            # and cleanup must not raise KeyError (which would mask any in-flight exception).
            os.environ.pop(_ArchListSetter._TORCH_CUDA_ARCH_LIST, None)
        else:
            os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
671
+
672
+
673
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
    """
    Builds and loads a PyTorch CUDA extension just-in-time via ``torch.utils.cpp_extension.load``.

    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param cpp_file: path to file containing extension's C++ interface
    :type cpp_file: str
    :param cuda_file: path to file containing extension's CUDA interface
    :type cuda_file: str

    :return: loaded PyTorch module
    """

    # Imported lazily so that torch is only required when JIT compilation is requested.
    from torch.utils.cpp_extension import load

    nvcc_flags = ["-std=c++17"]
    if cc in [90, 100, 101, 103]:
        # TORCH_CUDA_ARCH_LIST alone does not produce the architecture-specific "a" target
        # (e.g. sm_90a) for these compute capabilities, so it is requested explicitly.
        nvcc_flags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")

    include_dirs = [
        os.path.join(CUTLASS_PATH, "include"),
        os.path.join(CUTLASS_PATH, "tools/util/include"),
    ]
    # Build output is echoed only when the logger level is exactly DEBUG.
    be_verbose = logger.level == logging.DEBUG

    with _ArchListSetter(cc):
        return load(
            name,
            [cpp_file, cuda_file],
            extra_cuda_cflags=nvcc_flags,
            extra_include_paths=include_dirs,
            extra_ldflags=["-lcuda"],
            verbose=be_verbose
        )
710
+
711
+
712
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
    specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
    compiled, loaded, and returned.

    Three files are written into ``sourcedir``: ``<name>_kernel.cu``, ``<name>.cpp`` and
    ``setup.py``.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    if sourcedir != "" and not os.path.isdir(sourcedir):
        os.makedirs(sourcedir)

    cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
    extra_kw = {}
    if op.api == ApiVersion.v3x:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
    else:
        impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
        # Only the 2.x path fills the ``args`` placeholder; which argument snippet is used
        # depends on whether the kernel uses the stream-K threadblock swizzle.
        if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
        else:
            extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
    cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
    cuda_source = SubstituteTemplate(
        _PYTORCH_CUDA_TEMPLATE,
        {
            "includes": _PYTORCH_GEMM_INCLUDES[op.api],
            "declaration": op.rt_module.emit(),
            "procedural_name": op.procedural_name(),
            "impl": cuda_impl,
            "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
        },
    )
    with open(cuda_file, "w") as outfile:
        outfile.write(cuda_source)

    cpp_file = os.path.join(sourcedir, name + ".cpp")
    cpp_source = SubstituteTemplate(
        _PYTORCH_GEMM_CPP_TEMPLATE,
        {"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
    )
    with open(cpp_file, "w") as outfile:
        outfile.write(cpp_source)

    # For these compute capabilities the generated setup.py additionally requests the
    # architecture-specific "a" target (e.g. sm_90a) from nvcc.
    extra_compile_args = ""
    if cc in [90, 100, 101, 103]:
        extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
    _generate_setup(name, sourcedir, extra_compile_args)

    if jit:
        return _jit(name, cc, cpp_file, cuda_file)

    return None
779
+
780
+
781
def _pytorch_grouped_gemm(
    op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
):
    """
    Emits the sources of a PyTorch CUDA module wrapping the CUTLASS grouped GEMM described
    by ``op``, and optionally JIT-compiles and loads that module.

    ``<name>_kernel.cu``, ``<name>.cpp`` and ``setup.py`` are written into ``sourcedir``.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    :raises Exception: if ``op`` does not use the CUTLASS 2.x API
    """
    if op.api != ApiVersion.v2x:
        raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")

    must_create_dir = sourcedir != "" and not os.path.isdir(sourcedir)
    if must_create_dir:
        os.makedirs(sourcedir)

    # CUDA translation unit: kernel declaration plus the host-side launcher.
    kernel_path = os.path.join(sourcedir, name + "_kernel.cu")
    kernel_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
    kernel_values = {
        "includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
        "declaration": op.rt_module.emit(),
        "procedural_name": op.procedural_name(),
        "impl": kernel_impl,
        "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
    }
    kernel_source = SubstituteTemplate(_PYTORCH_CUDA_TEMPLATE, kernel_values)
    with open(kernel_path, "w") as kernel_out:
        kernel_out.write(kernel_source)

    # C++ translation unit generated from the grouped GEMM C++ template.
    binding_path = os.path.join(sourcedir, name + ".cpp")
    binding_values = {
        "name": name,
        "description": f"CUTLASS {op.procedural_name()} grouped GEMM",
    }
    binding_source = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE, binding_values)
    with open(binding_path, "w") as binding_out:
        binding_out.write(binding_source)

    _generate_setup(name, sourcedir)

    if not jit:
        return None
    return _jit(name, cc, binding_path, kernel_path)
836
+
837
+
838
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Emits the sources of a PyTorch CUDA module wrapping the CUTLASS Conv2d described by
    ``op``, and optionally JIT-compiles and loads that module.

    ``<name>_kernel.cu``, ``<name>.cpp`` and ``setup.py`` are written into ``sourcedir``.
    The implementation and C++ templates are selected from ``op.conv_kind``.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    Note that when the conv kind is `dgrad` or `wgrad`, the generated kernel takes the size of
    the input `(N, C, H, W)` or of the weight `(K, C, R, S)` as an argument. This is because
    there are multiple valid solutions for H/W/R/S given the same P/Q.

    :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
    """
    must_create_dir = sourcedir != "" and not os.path.isdir(sourcedir)
    if must_create_dir:
        os.makedirs(sourcedir)
    kernel_path = os.path.join(sourcedir, name + "_kernel.cu")

    # Both gradient kinds share one C++ template; only fprop has its own.
    if op.conv_kind == ConvKind.Fprop:
        impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Dgrad:
        impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
    elif op.conv_kind == ConvKind.Wgrad:
        impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
        cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE

    impl_values = {
        "name": name,
        "conv_kind_name": ConvKindNames[op.conv_kind].capitalize(),
        "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
    }
    kernel_impl = SubstituteTemplate(impl_template, impl_values)
    kernel_values = {
        "includes": _PYTORCH_CONV2D_INCLUDES,
        "declaration": op.rt_module.emit(),
        "procedural_name": op.procedural_name(),
        "impl": kernel_impl,
        "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
    }
    kernel_source = SubstituteTemplate(_PYTORCH_CUDA_TEMPLATE, kernel_values)
    with open(kernel_path, "w") as kernel_out:
        kernel_out.write(kernel_source)

    binding_path = os.path.join(sourcedir, name + ".cpp")
    binding_values = {
        "name": name,
        "description": f"CUTLASS {op.procedural_name()} Conv2d",
    }
    binding_source = SubstituteTemplate(cpp_template, binding_values)
    with open(binding_path, "w") as binding_out:
        binding_out.write(binding_source)

    _generate_setup(name, sourcedir)

    if not jit:
        return None
    return _jit(name, cc, binding_path, kernel_path)
903
+
904
+
905
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
    """
    Entry point for PyTorch emission: dispatches on the type of ``op`` to the GEMM, grouped
    GEMM, or Conv2d emitter, passing it ``op.device_op()``. If the ``jit`` parameter is set
    to true, the module is just-in-time compiled, loaded, and returned.

    The result of this method is files within ``sourcedir`` that can be used for building
    a PyTorch module.

    :param op: operation to emit in the module
    :param name: name of the module to generate
    :type name: str
    :param cc: compute capability of the device the module should target
    :type cc: int
    :param jit: whether the module should be just-in-time compiled
    :type jit: bool
    :param sourcedir: directory to which generated source files should be written
    :type sourcedir: str

    :return: loaded PyTorch module (if ``jit=True``) or None
    :raises Exception: if ``op`` is not one of the supported operation types
    """
    device_op = op.device_op()
    # The type check is made on ``op`` itself; the emitters receive the device-level operation.
    emitter_args = (device_op, name, cc, jit, sourcedir)

    if isinstance(op, GemmOperationUniversal):
        return _pytorch_gemm(*emitter_args)
    if isinstance(op, GemmOperationGrouped):
        return _pytorch_grouped_gemm(*emitter_args)
    if isinstance(op, Conv2dOperation):
        return _pytorch_conv2d(*emitter_args)

    raise Exception(
        f"Operation type {type(op)} is not currently supported for PyTorch emission."
    )
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.epilogue.epilogue import (
34
+ get_activations,
35
+ get_activation_epilogue,
36
+ gelu,
37
+ hardswish,
38
+ identity,
39
+ leaky_relu,
40
+ relu,
41
+ sigmoid,
42
+ silu,
43
+ tanh,
44
+ trace
45
+ )
46
+
47
+ from cutlass_cppgen.epilogue.evt_ops import (
48
+ max,
49
+ multiply_add,
50
+ sum,
51
+ permute,
52
+ reshape,
53
+ maximum,
54
+ minimum,
55
+ exp
56
+ )
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Registry of elementwise epilogues
35
+
36
+ Elementwise epilogues can be added to many CUTLASS kernels in the CUTLASS Python interface via
37
+ code like the following for GEMM:
38
+
39
+ .. highlight:: python
40
+ .. code-block:: python
41
+
42
+ plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
43
+ plan.activation = cutlass_cppgen.epilogue.relu
44
+ """
45
+
46
+ from cutlass_cppgen.backend import epilogue, device_cc
47
+
48
+
49
# Module-level aliases for the activation functors exposed by
# ``cutlass_cppgen.backend.epilogue``, so they can be referenced through this module
# (e.g. ``cutlass_cppgen.epilogue.relu``).
gelu = epilogue.gelu
hardswish = epilogue.hardswish
identity = epilogue.identity
leaky_relu = epilogue.leaky_relu
relu = epilogue.relu
sigmoid = epilogue.sigmoid
silu = epilogue.silu
tanh = epilogue.tanh


# Registry of supported activations: returned by ``get_activations`` and used by
# ``get_activation_epilogue`` to validate its ``activation`` argument.
_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]
60
+
61
+
62
def get_activations() -> list:
    """
    Returns a list of available activation functions

    The module-level registry list itself is returned (not a copy), so in-place
    modifications made by a caller are visible to every later user of the registry.

    :return: list of available activation functions
    :rtype: list
    """
    return _activations
70
+
71
+
72
def get_activation_epilogue(
    activation,
    element_output,
    elements_per_access,
    element_accumulator,
    element_compute,
):
    """
    Build the epilogue functor matching an activation function and the kernel's data types
    and alignment

    ``identity`` maps to ``epilogue.LinearCombination``; every other registered activation
    maps to ``epilogue.LinearCombinationGeneric`` parameterized by that activation.

    :param activation: elementwise activation function to use
    :param element_output: data type of the output
    :param elements_per_access: alignment of operand C of the kernel
    :type elements_per_access: int
    :param element_accumulator: data type of the accumulated output C
    :param element_compute: data type in which compute operations should be performed

    :return: epilogue functor
    :raises Exception: if ``activation`` is not one of the registered activations
    """
    is_registered = activation in _activations
    if not is_registered:
        raise Exception(
            f"Unsupported activation type {activation}. Available activations are: {_activations}"
        )

    # Arguments shared by both functor constructors, in positional order.
    functor_args = (
        element_output,
        elements_per_access,
        element_accumulator,
        element_compute,
    )

    if activation == identity:
        return epilogue.LinearCombination(*functor_args)
    return epilogue.LinearCombinationGeneric(activation, *functor_args)
109
+
110
+
111
+ """
112
+ Frontend for EVT that generates epilogue functor through tracing the input function
113
+ """
114
+ from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
115
+
116
+
117
def trace(fn, example_tensors, **kwargs):
    """
    Trace `fn(**example_tensors)` and generates epilogue visitor

    :param fn: Python callable, or string holding the source of the epilogue function
    :param example_tensors: example inputs for fn
    :type example_tensors: dict
    :param kwargs: forwarded to the frontend constructor; ``cc`` selects the compute
        capability and falls back to ``device_cc()`` when it is missing or falsy

    :return: the traced epilogue functor
    :raises NotImplementedError: if ``fn`` is neither callable nor a string

    .. highlight:: python
    .. code-block:: python
        import cutlass_cppgen.backend.evt

        # Define epilogue function as Python callable
        def example_fn(accum, C, alpha, beta, gamma):
            D = ((accum + C) * alpha - gamma) / beta
            return D

        # Define the example tensors
        example_inputs = {
            "accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
            "alpha": 1.5,
            "beta": 0.5,
            "gamma": 2.5,
            "D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
        }

        # Generate the epilogue functor
        epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
    """
    # Imported here because this module has no top-level import of either name; without
    # them the string branch below fails with NameError.
    import ast
    import textwrap

    if callable(fn):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

        # Installed as a staticmethod so that ``fn`` is stored unbound (no implicit ``self``).
        setattr(EpilogueFunctor, "__call__", staticmethod(fn))
    elif isinstance(fn, str):
        class EpilogueFunctor(PythonASTFrontend):
            def __init__(self, cc=None, **kwargs):
                # Dedent so that source copied from an indented context still parses.
                self.source = textwrap.dedent(fn)
                if not cc:
                    cc = device_cc()
                super().__init__(cc, **kwargs)

            def parse(self, example_inputs) -> None:
                # Parse the stored source text and visit the resulting AST.
                self.example_inputs = example_inputs
                self.ast = ast.parse(self.source)
                self.visit(self.ast)
    else:
        raise NotImplementedError("Expect a callable Python function")

    epilogue_functor = EpilogueFunctor(**kwargs)
    epilogue_functor.trace(example_tensors)
    return epilogue_functor
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Collection of builtin functions used for host reference in EVT
35
+ """
36
+
37
+ import numpy as np
38
+
39
+ from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
40
+
41
+ if is_torch_available():
42
+ import torch
43
+
44
+
45
def multiply_add(x, y, z):
    """
    Host reference for a multiply-add: returns ``x * y + z``
    """
    product = x * y
    return product + z
47
+
48
+
49
def sum(x, dim):
    """
    Host reference for a sum reduction of ``x`` over the dimensions ``dim``

    NumPy inputs are reduced with ``x.sum(axis=tuple(dim))`` and torch inputs with
    ``torch.sum(x, dim)``. Returns ``None`` for any other input type.
    """
    if is_numpy_tensor(x):
        reduce_axes = tuple(dim)
        return x.sum(axis=reduce_axes)
    if is_torch_tensor(x):
        return torch.sum(x, dim)
    return None
54
+
55
+
56
def max(x, dim):
    """
    Host reference for a max reduction of ``x`` over the dimensions ``dim``

    NumPy inputs are reduced with ``x.max(axis=tuple(dim))`` and torch inputs with
    ``torch.amax(x, dim)``. Returns ``None`` for any other input type.
    """
    if is_numpy_tensor(x):
        reduce_axes = tuple(dim)
        return x.max(axis=reduce_axes)
    if is_torch_tensor(x):
        return torch.amax(x, dim)
    return None
61
+
62
+
63
def maximum(x, y):
    """
    Host reference for the elementwise maximum of ``x`` and ``y``

    Dispatches on the type of ``x`` only. On the torch path ``y`` is wrapped with
    ``torch.tensor`` before the comparison. Returns ``None`` when ``x`` is neither a
    NumPy nor a torch tensor.
    """
    if is_numpy_tensor(x):
        return np.maximum(x, y)
    if is_torch_tensor(x):
        other = torch.tensor(y)
        return torch.maximum(x, other)
    return None
68
+
69
+
70
def minimum(x, y):
    """
    Host reference for the elementwise minimum of ``x`` and ``y``

    Dispatches on the type of ``x`` only. On the torch path ``y`` is wrapped with
    ``torch.tensor`` before the comparison. Returns ``None`` when ``x`` is neither a
    NumPy nor a torch tensor.
    """
    if is_numpy_tensor(x):
        return np.minimum(x, y)
    if is_torch_tensor(x):
        other = torch.tensor(y)
        return torch.minimum(x, other)
    return None
75
+
76
def exp(x):
    """
    Host reference for the elementwise exponential of ``x``

    Uses ``np.exp`` for NumPy inputs and ``torch.exp`` for torch inputs. Returns ``None``
    for any other input type.
    """
    if is_numpy_tensor(x):
        return np.exp(x)
    if is_torch_tensor(x):
        return torch.exp(x)
    return None
81
+
82
+
83
+ ##############################################################################
84
+ # Layout manipulate nodes
85
+ ##############################################################################
86
+
87
def permute(x, indices: tuple):
    """
    Host reference for reordering the dimensions of ``x`` according to ``indices``

    Uses ``np.transpose`` for NumPy inputs and ``Tensor.permute`` for torch inputs.
    Returns ``None`` for any other input type.
    """
    if is_numpy_tensor(x):
        return np.transpose(x, axes=indices)
    if is_torch_tensor(x):
        return x.permute(*indices)
    return None
92
+
93
+
94
def reshape(x, new_shape: tuple):
    """
    Host reference for reshaping ``x`` to ``new_shape``

    Uses ``np.reshape`` for NumPy inputs and ``Tensor.view`` for torch inputs. Returns
    ``None`` for any other input type.

    :param x: tensor to reshape
    :param new_shape: target shape
    :type new_shape: tuple
    """
    if is_numpy_tensor(x):
        # The shape is passed positionally: the ``newshape`` keyword is deprecated in
        # NumPy 2.1 in favour of ``shape``, while the positional form works on all versions.
        return np.reshape(x, new_shape)
    elif is_torch_tensor(x):
        return x.view(new_shape)
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Classes containing valid operations for a given compute capability and data types.
35
+ """
36
+
37
+ from itertools import combinations_with_replacement
38
+ import logging
39
+
40
+ import cutlass_library
41
+ from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
42
+
43
+ import cutlass_cppgen
44
+ from cutlass_cppgen.utils.check import valid_stage_count
45
+ from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
46
+
47
+
48
# Compute capabilities for which kernel options are registered. `OptionRegistry` builds one
# `ArchOptions` per entry (and per operation kind), and `ArchOptions` looks up the matching
# `GenerateSM<cc>` function in `cutlass_library.generator`.
_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]
49
+
50
+
51
class KernelsForDataType:
    """
    Container class for keeping track of kernels that correspond to a particular combination
    of data types for operands A, B, and accumulator

    Kernels are grouped under a string key of the form ``"<alignment_A> <alignment_B> <alignment_C>"``.
    """

    def __init__(self, datatype_comb: tuple, layout_comb: tuple):
        self.datatype_comb = datatype_comb
        self.layout_comb = layout_comb
        self.math_operations = set()

        # Dictionary mapping from an alignment key (str of the form "<A> <B> <C>") to a list of
        # kernels that fit the alignment constraint for the data type combination
        self.kernels_by_alignment = {}

    def add(self, operation):
        """
        Add an operation to the list of supported kernels

        :param operation: operation to register; its A/B/C alignments form the key under which
            it is stored, and its math operation is recorded in `math_operations`
        """
        alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
        self.kernels_by_alignment.setdefault(alignment_key, []).append(operation)
        self.math_operations.add(operation.tile_description.math_instruction.math_operation)

    def alignments(self, operand: str):
        """
        Returns an unsorted list of alignments supported by this data type combination

        :param operand: identifier of operand in question (e.g., A, B, C)
        :type operand: str

        :return: unsorted list of alignments supported by this data type combination
        :rtype: list
        """
        operand_idx = self._operand_idx(operand)
        return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()]

    @property
    def all_operations(self):
        """
        Returns a list of all operations supported by this data type combination

        :return: list of all operations supported by this data type combination
        :rtype: list
        """
        ops = []
        for alignment_ops in self.kernels_by_alignment.values():
            ops.extend(alignment_ops)
        return ops

    def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
        """
        Returns the first kernel registered under the first alignment key in sorted (string) order

        :param math_operation: math operation to restrict to, or None to consider all kernels
        :type math_operation: cutlass_cppgen.MathOperation

        :return: first matching operation; an IndexError is raised if there is none
        """
        key = sorted(list(self.kernels_by_alignment.keys()))[0]
        kernels = self.kernels_by_alignment[key]
        if math_operation is not None:
            kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
        return kernels[0]

    def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
        """
        Returns operations satisfying the alignment constraints

        :param alignment_A: alignment constraint of operations to return
        :type alignment_A: int
        :param alignment_B: alignment constraint of operations to return
        :type alignment_B: int
        :param alignment_C: alignment constraint of operations to return
        :type alignment_C: int
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: list of operations
        :rtype: list
        """
        key = f"{alignment_A} {alignment_B} {alignment_C}"

        if key not in self.kernels_by_alignment:
            og_key = key
            # Reconcile A, B, and C alignments by trying to align to the minimum
            min_alignment = min(alignment_A, alignment_B, alignment_C)
            key = f"{min_alignment} {min_alignment} {min_alignment}"
            if key not in self.kernels_by_alignment:
                # Finally, go through all available alignment combinations (largest first) and find
                # one whose alignments each evenly divide the corresponding alignment passed in.
                key = None
                alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
                for align_A, align_B, align_C in alignments:
                    if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
                        key = f"{align_A} {align_B} {align_C}"
                        break

                if key is None:
                    raise Exception(
                        f"No operations of alignment {og_key} found for data type and layout "
                        f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
                        f"are {self.kernels_by_alignment.keys()}"
                    )

        ops = self.kernels_by_alignment[key]
        if math_operation is not None:
            ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
        return ops

    def _operand_idx(self, key: str) -> int:
        """
        Returns the position of operand `key` within an alignment key

        :param key: identifier of the operand ("A", "B", or "C")
        :type key: str

        :return: index of the operand (0 for A, 1 for B, 2 for C)
        :rtype: int
        """
        operand_list = ["A", "B", "C"]
        if key not in operand_list:
            # Report the offending value; the parameter is named `key` (there is no `operand` here)
            raise Exception(f"Unexpected operand {key}")

        return operand_list.index(key)

    # NOTE: `operand=str` makes the `str` type the default value (it was likely meant as an
    # annotation). The signature is kept for compatibility; callers must pass `operand` explicitly,
    # as the default is rejected by `_operand_idx`.
    def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int:
        """
        Returns the most preferable alignment for a given shape and layout

        :param shape: extent of each dimension of the tensor
        :type shape: tuple
        :param layout: layout of the tensor
        :type layout: cutlass_cppgen.LayoutType
        :param operand: descriptor of the operand in question
        :type operand: str

        :return: maximum alignment supported by the data type combination and tensor size
        :rtype: int
        """
        operand_idx = self._operand_idx(operand)

        # Determine the leading dimension of the shape
        if layout == cutlass_cppgen.LayoutType.ColumnMajor:
            ld = shape[-2]
        elif layout == cutlass_cppgen.LayoutType.RowMajor:
            ld = shape[-1]
        elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
            ld = shape[-1]
        else:
            raise Exception(f"Unexpected or unsupported layout {layout}")

        # Compare the operand's alignments numerically. Sorting the raw string keys would order
        # them lexicographically (e.g., "8 8 8" ahead of "16 16 16") and could return a
        # smaller alignment than the maximum one supported.
        candidates = sorted(
            {int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()},
            reverse=True,
        )
        for alignment in candidates:
            if ld % alignment == 0:
                return alignment

        # Default to alignment of 1 if no others match
        return 1

    def sort(self):
        """
        Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
        """
        def threadblock_volume(op):
            # Product of the three threadblock dimensions
            tb_shape = op.tile_description.threadblock_shape
            return tb_shape[0] * tb_shape[1] * tb_shape[2]

        for alignment in self.kernels_by_alignment.keys():
            self.kernels_by_alignment[alignment].sort(key=threadblock_volume, reverse=True)

    def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
        """
        Returns whether `math_operation` is supported by at least one operation.

        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether math_operation is supported by at least one operation
        :rtype: bool
        """
        return math_operation is None or math_operation in self.math_operations
218
+
219
+
220
class ArchOptions:
    """
    Structure for keeping track of kernels available on a given compute capability

    :param target_cc: compute capability of the device on which kernels will be run
    :type target_cc: int
    :param kernel_cc: compute capability of the kernels to generate
    :type kernel_cc: int
    :param operation_kind: type of operation to register
    :type operation_kind: cutlass_library.OperationKind
    :param gemm_kinds: types of GEMM operations that can be included
    :type gemm_kinds: list
    :param allowed_math_operations: types of primitive math operations allowed. If None, defaults to
        multiply_add, multiply_add_saturate, multiply_add_mixed_input_upcast, and multiply_add_fast_f32
    :type allowed_math_operations: list
    """

    def __init__(
        self,
        target_cc: int,
        kernel_cc: int,
        operation_kind: cutlass_library.OperationKind,
        gemm_kinds: list,
        allowed_math_operations: list = None,
    ):
        self.cc = kernel_cc

        # Dictionary with following structure:
        #  Key: OpcodeClass
        #  Value: Dictionary with the following structure:
        #     Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType)),
        #          representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
        #     Value: KernelsForDataType
        self.operations_by_opclass = {}
        self.op_class = None

        # Build the default list per instance rather than in the signature, so that instances
        # do not share (and cannot accidentally mutate) a single default list object.
        if allowed_math_operations is None:
            allowed_math_operations = [
                cutlass_library.MathOperation.multiply_add,
                cutlass_library.MathOperation.multiply_add_saturate,
                cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
                cutlass_library.MathOperation.multiply_add_fast_f32
            ]
        self.allowed_math_operations = allowed_math_operations

        # No kernels are registered when mixing CC 90 and CC 100 between target and kernel
        # (presumably because these kernels are not cross-compatible -- TODO confirm).
        # `op_class` stays None in this case.
        if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100:
            return

        # Identify the method within CUTLASS generator script that generates kernel
        # descriptions for the target CC
        generate_function_name = "GenerateSM" + str(kernel_cc)
        if not hasattr(cutlass_library.generator, generate_function_name):
            cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
            return
        generate_function = getattr(cutlass_library.generator, generate_function_name)

        # Initialize a default manifest and populate it with valid kernel descriptions
        # for the target CC
        args = [
            "--kernels=all",
            f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
        ]
        manifest_args = cutlass_library.generator.define_parser().parse_args(args)
        manifest = cutlass_library.manifest.Manifest(manifest_args)
        generate_function(manifest, cutlass_cppgen._nvcc_version)

        if operation_kind not in manifest.operations:
            # No kernels generated for this architecture, this could be because the CUDA
            # toolkit is insufficient to support operations in this CC
            cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
            return

        # Only one CC should be returned, given the setup above of calling only the generation scripts
        # for a given CC
        if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
            raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
                            "is sufficient for the architecture in question.")

        # Iterate through the available operations for this operation kind and
        # find available opclasses and data types
        for _, op_list in manifest.operations[operation_kind][kernel_cc].items():
            for op in op_list:

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    if op.gemm_kind not in gemm_kinds:
                        continue

                mi = op.tile_description.math_instruction
                if mi.math_operation not in self.allowed_math_operations:
                    continue

                # Prune operations that don't fit in shared memory
                td = td_from_profiler_op(op)
                if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
                    continue

                if mi.opcode_class not in self.operations_by_opclass:
                    self.operations_by_opclass[mi.opcode_class] = {}

                datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
                layout_comb = (op.A.layout, op.B.layout)

                # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
                if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
                    # TF32 kernels only supported on SM80 and beyond
                    if self.cc < 80:
                        continue
                    elif self.cc == 90 or self.cc == 100:
                        if (op.A.element != cutlass_library.DataType.f32
                                or op.B.element != cutlass_library.DataType.f32
                                or op.C.element != cutlass_library.DataType.f32):
                            continue

                    datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)

                opclass_dict = self.operations_by_opclass[mi.opcode_class]
                key = (datatype_comb, layout_comb)
                if key not in opclass_dict:
                    opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
                opclass_dict[key].add(op)

        # Set the default opclass to TensorOp, if available. Otherwise default to SIMT
        if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
            self.op_class = cutlass_library.OpcodeClass.TensorOp
        else:
            self.op_class = cutlass_library.OpcodeClass.Simt

        # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
        # Here, we generate additional versions via a generic TileDescription.
        if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
            self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}

        if operation_kind == cutlass_library.OperationKind.Gemm:
            types = [
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
                (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            # Add FP8 A/B/C
            fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
            for type_comb in combinations_with_replacement(fp8_types, 3):
                types.append(type_comb)

            # Add FP8 A/B with FP32 C
            for type_comb in combinations_with_replacement(fp8_types, 2):
                types.append(type_comb + (cutlass_cppgen.DataType.f32,))

            layouts = [
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
                (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
            ]
        elif operation_kind == cutlass_library.OperationKind.Conv2d:
            types = [
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
                (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
                (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
            ]

            layouts = [
                (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
            ]
        else:
            raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")

        alignment = 1
        epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
        swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8
        for type_comb in types:
            for layout_comb in layouts:
                comb = (type_comb, layout_comb)
                # Keep any SIMT kernels the generator already produced for this combination
                if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
                    continue

                A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
                B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
                C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
                math_inst = cutlass_library.MathInstruction(
                    [1, 1, 1],
                    type_comb[0],
                    type_comb[1],
                    type_comb[2],
                    cutlass_library.OpcodeClass.Simt,
                    cutlass_library.MathOperation.multiply_add
                )

                td = cutlass_library.TileDescription(
                    [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)

                # Prune operations that don't fit in shared memory
                if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
                    continue

                new_kernels = KernelsForDataType(type_comb, layout_comb)

                if operation_kind == cutlass_library.OperationKind.Gemm:
                    new_operation = cutlass_library.manifest.GemmOperation(
                        cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
                        td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
                    new_kernels.add(new_operation)
                elif operation_kind == cutlass_library.OperationKind.Conv2d:
                    for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
                        new_operation = cutlass_library.manifest.Conv2dOperation(
                            conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
                            A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
                            group_mode=GroupMode.SingleGroup
                        )
                        new_kernels.add(new_operation)

                self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels

        # Sort all operations
        for oc in self.operations_by_opclass.keys():
            for comb in self.operations_by_opclass[oc].keys():
                self.operations_by_opclass[oc][comb].sort()

    def opclass_supports_combination(
        self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
    ) -> bool:
        """
        Returns whether the provided operation class supports the provided data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
        :type datatype_comb: tuple[cutlass_library.DataType]
        :param layout_comb: tuple of layouts for (layout_A, layout_B)
        :type layout_comb: tuple[cutlass_library.LayoutType]
        :param math_operation: math operation to consider or None if any can be considered
        :type math_operation: cutlass_cppgen.MathOperation

        :return: whether the operation class supports the provided data type and layout combination
            (and `math_operation`, if one is given)
        :rtype: bool
        """
        if op_class not in self.operations_by_opclass:
            raise Exception(f"Unexpected or unsupported operation class {op_class}")

        if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
            if math_operation is not None:
                return operations.supports_math_operation(math_operation)
            else:
                return True

        return False

    def supporting_opclasses(
        self,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> set:
        """
        Returns a set of operation classes that support the provided data type combination

        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: set of operation classes that support the provided data type combination
        :rtype: set
        """
        supporting_op_classes = set()
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)

        for op_class in self.operations_by_opclass.keys():
            if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
                supporting_op_classes.add(op_class)
        return supporting_op_classes

    def operations(
        self,
        op_class: cutlass_library.OpcodeClass,
        element_a: cutlass_library.DataType,
        element_b: cutlass_library.DataType,
        element_accumulator: cutlass_library.DataType,
        layout_a: cutlass_library.LayoutType,
        layout_b: cutlass_library.LayoutType,
        math_operation: cutlass_library.MathOperation,
    ) -> KernelsForDataType:
        """
        Returns the kernels of the provided operation class that are registered for the provided
        data type and layout combination

        :param op_class: operation class to consider
        :type op_class: cutlass_library.OpcodeClass
        :param element_a: data type of operand A
        :type element_a: cutlass_library.DataType
        :param element_b: data type of operand B
        :type element_b: cutlass_library.DataType
        :param element_accumulator: data type of accumulator
        :type element_accumulator: cutlass_library.DataType
        :param layout_a: layout of operand A
        :type layout_a: cutlass_library.LayoutType
        :param layout_b: layout of operand B
        :type layout_b: cutlass_library.LayoutType
        :param math_operation: math operation to consider
        :type math_operation: cutlass_cppgen.MathOperation

        :return: container of kernels by alignment supported by the provided combination of parameters
        :rtype: KernelsForDataType
        """
        datatype_comb = (element_a, element_b, element_accumulator)
        layout_comb = (layout_a, layout_b)
        if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
            raise Exception(
                f"Data type layout combination {datatype_comb}, {layout_comb} "
                f"is not supported by opcode class {op_class} on CC {self.cc}."
            )
        return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]
544
+
545
+
546
class OptionRegistry:
    """
    Container of all architecture-specific options

    :param target_cc: compute capability of the device on which operations will be run
    :type target_cc: int
    """

    def __init__(self, target_cc: int):
        # Maps kernel compute capability -> operation kind -> ArchOptions
        self.registry = {}

        if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
            raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")

        gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
        operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
        # Construct options for each CC
        for kernel_cc in _generator_ccs:
            self.registry[kernel_cc] = {}
            for opkind in operation_kinds:
                self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)

    def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
        """
        Returns the options registered for kernels of compute capability `cc` and operation kind `op_kind`

        :param cc: compute capability of the kernels in question; must be one of `_generator_ccs`
        :type cc: int
        :param op_kind: kind of operation to retrieve options for
        :type op_kind: cutlass_library.OperationKind

        :return: options registered for the compute capability and operation kind
        :rtype: ArchOptions
        """
        # Fail with an explicit message rather than the opaque "'NoneType' object is not
        # subscriptable" that indexing the result of `dict.get` would produce.
        if cc not in self.registry:
            raise Exception(
                f"No options registered for compute capability {cc}. "
                f"Registered compute capabilities are {sorted(self.registry.keys())}"
            )
        return self.registry[cc][op_kind]
build/torch212-cu130-x86_64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
34
+ from cutlass_cppgen.op.gemm import Gemm
35
+ from cutlass_cppgen.op.gemm_grouped import GroupedGemm
36
+ from cutlass_cppgen.op.op import OperationBase