Kernels:
Trusted publisher
Uploaded using `kernel-builder` (batch 5/6).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py +116 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py +134 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py +47 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py +258 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py +98 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py +329 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py +168 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py +33 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py +272 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py +194 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py +53 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py +91 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py +254 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py +324 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py +336 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py +294 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py +306 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py +277 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py +137 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py +42 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py +143 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py +120 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py +169 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py +64 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py +90 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py +217 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py +164 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py +53 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py +97 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py +59 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py +319 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py +46 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py +109 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py +2145 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py +509 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py +121 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py +140 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py +455 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py +35 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py +33 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py +126 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py +33 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py +267 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py +936 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py +56 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py +176 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py +98 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py +569 -0
- build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py +36 -0
.gitattributes
CHANGED
|
@@ -65,3 +65,4 @@ build/torch212-cu132-x86_64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs dif
|
|
| 65 |
build/torch211-cu128-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 66 |
build/torch211-cu130-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 67 |
build/torch212-cu130-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 65 |
build/torch211-cu128-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 66 |
build/torch211-cu130-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 67 |
build/torch212-cu130-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
build/torch212-cu132-aarch64-linux/_deep_gemm_cuda_47ad41b.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_emitter.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Emitter for Sm100 Epilogue Visitor
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_library import DataType, DataTypeTag, EpilogueScheduleTag, OpcodeClassTag
|
| 38 |
+
from cutlass_cppgen.backend.library import to_blackwell_threadblock_shape
|
| 39 |
+
from cutlass_cppgen.backend import GemmOperationUniversal
|
| 40 |
+
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
| 41 |
+
from cutlass_cppgen.backend.evt.ir.node import TupleEmitter
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Sm100CollectiveEpilogue:
|
| 45 |
+
def __init__(self, tile_description,
|
| 46 |
+
kernel_schedule,
|
| 47 |
+
epilogue_schedule,
|
| 48 |
+
element_accumulator,
|
| 49 |
+
element_d,
|
| 50 |
+
fusion_callbacks) -> None:
|
| 51 |
+
|
| 52 |
+
self.cta_tile_mnk, _ = to_blackwell_threadblock_shape(tile_description, tile_description.cluster_shape, kernel_schedule)
|
| 53 |
+
self.element_accumulator = element_accumulator
|
| 54 |
+
if fusion_callbacks.dag_ir.has_node("C"):
|
| 55 |
+
self.element_c = fusion_callbacks.dag_ir.get_node_meta("C").element
|
| 56 |
+
else:
|
| 57 |
+
self.element_c = DataType.void
|
| 58 |
+
self.element_d = element_d
|
| 59 |
+
self.schedule = epilogue_schedule
|
| 60 |
+
self.fusion_callbacks = fusion_callbacks
|
| 61 |
+
self.opclass = tile_description.math_instruction.opcode_class
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def CtaTileMNK(self) -> str:
|
| 65 |
+
"""
|
| 66 |
+
The threadblock shape
|
| 67 |
+
"""
|
| 68 |
+
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def EpilogueTileType(self) -> str:
|
| 72 |
+
"""
|
| 73 |
+
The epilogue tile type
|
| 74 |
+
"""
|
| 75 |
+
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def Schedule(self) -> str:
|
| 79 |
+
return EpilogueScheduleTag[self.schedule]
|
| 80 |
+
|
| 81 |
+
def emit(self):
|
| 82 |
+
tuple_emitter = TupleEmitter("int64_t")
|
| 83 |
+
stride_D_str = self.fusion_callbacks.dag_ir.get_node_meta("D").underlying_impl.stride_mnl
|
| 84 |
+
stride_C_str = stride_D_str
|
| 85 |
+
if self.fusion_callbacks.dag_ir.has_node("C"):
|
| 86 |
+
stride_C_str = self.fusion_callbacks.dag_ir.get_node_meta("C").underlying_impl.stride_mnl
|
| 87 |
+
|
| 88 |
+
callback_decl, callback_name = self.fusion_callbacks.emit()
|
| 89 |
+
return callback_name, f"""
|
| 90 |
+
using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor<
|
| 91 |
+
{OpcodeClassTag[self.opclass]},
|
| 92 |
+
{self.CtaTileMNK}, {self.EpilogueTileType},
|
| 93 |
+
{DataTypeTag[self.element_accumulator]}, {DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
| 94 |
+
{self.Schedule}, {stride_C_str}, {stride_D_str},
|
| 95 |
+
false /* IsPerColScaleSupported */,
|
| 96 |
+
false /* IsBlockScaleSupported */
|
| 97 |
+
>;
|
| 98 |
+
{callback_decl}
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class Sm100Emitter:
|
| 103 |
+
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
| 104 |
+
fusion_callbacks = FusionCallbacks(graph, cc=100, emit_CD=False)
|
| 105 |
+
|
| 106 |
+
self.collective_epilogue = Sm100CollectiveEpilogue(
|
| 107 |
+
tile_description=operation.tile_description,
|
| 108 |
+
kernel_schedule=operation.tile_description.kernel_schedule,
|
| 109 |
+
epilogue_schedule=operation.tile_description.epilogue_schedule,
|
| 110 |
+
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
| 111 |
+
element_d=fusion_callbacks.dag_ir.get_node_meta("D").element,
|
| 112 |
+
fusion_callbacks=fusion_callbacks
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def emit(self):
|
| 116 |
+
return self.collective_epilogue.emit()
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm100_nodes.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from pycute import product
|
| 34 |
+
|
| 35 |
+
from cutlass_library import DataTypeSize, DataTypeTag
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.backend.evt.ir import AuxLoadImpl, AuxStoreImpl
|
| 38 |
+
import cutlass_cppgen.backend.evt.backend.sm90_nodes as sm90_nodes
|
| 39 |
+
|
| 40 |
+
from cutlass_cppgen.backend.library import FloatRoundStyleTag
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
Sm100AccumulatorImpl = sm90_nodes.Sm90AccumulatorImpl
|
| 44 |
+
Sm100LoadSrcImpl = sm90_nodes.Sm90LoadSrcImpl
|
| 45 |
+
Sm100ScalarBroadcastImpl = sm90_nodes.Sm90ScalarBroadcastImpl
|
| 46 |
+
Sm100RowBroadcastImpl = sm90_nodes.Sm90RowBroadcastImpl
|
| 47 |
+
Sm100ColumnBroadcastImpl = sm90_nodes.Sm90ColumnBroadcastImpl
|
| 48 |
+
Sm100ComputeImpl = sm90_nodes.Sm90ComputeImpl
|
| 49 |
+
Sm100StoreDImpl = sm90_nodes.Sm90StoreDImpl
|
| 50 |
+
Sm100ColumnReductionImpl = sm90_nodes.Sm90ColumnReductionImpl
|
| 51 |
+
Sm100RowReductionImpl = sm90_nodes.Sm90RowReductionImpl
|
| 52 |
+
Sm100ScalarReductionImpl = sm90_nodes.Sm90ScalarReductionImpl
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Sm100AuxLoadImpl(AuxLoadImpl):
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
def descriptor(self) -> str:
|
| 59 |
+
"""
|
| 60 |
+
Descriptor for Aux Load
|
| 61 |
+
"""
|
| 62 |
+
return f"{self.name_camel}Descriptor"
|
| 63 |
+
|
| 64 |
+
def decl_descriptor(self) -> str:
|
| 65 |
+
"""
|
| 66 |
+
Declare the descriptor type
|
| 67 |
+
"""
|
| 68 |
+
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def type_decl(self):
|
| 72 |
+
"""
|
| 73 |
+
Return the string defining the type
|
| 74 |
+
"""
|
| 75 |
+
if self._type_decl is not None:
|
| 76 |
+
return self._type_decl
|
| 77 |
+
|
| 78 |
+
self._type_decl = self.decl_descriptor()
|
| 79 |
+
self._type_decl += f"""
|
| 80 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
|
| 81 |
+
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
| 82 |
+
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
|
| 83 |
+
>;
|
| 84 |
+
"""
|
| 85 |
+
return self._type_decl
|
| 86 |
+
|
| 87 |
+
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
|
| 88 |
+
"""
|
| 89 |
+
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
| 90 |
+
"""
|
| 91 |
+
return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class Sm100AuxStoreImpl(AuxStoreImpl):
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def descriptor(self) -> str:
|
| 98 |
+
"""
|
| 99 |
+
Descriptor for Aux Load
|
| 100 |
+
"""
|
| 101 |
+
return f"{self.name_camel}Descriptor"
|
| 102 |
+
|
| 103 |
+
def decl_descriptor(self) -> str:
|
| 104 |
+
"""
|
| 105 |
+
Declare the descriptor type
|
| 106 |
+
"""
|
| 107 |
+
return f"""
|
| 108 |
+
using {self.descriptor} = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor<
|
| 109 |
+
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
|
| 110 |
+
>;
|
| 111 |
+
"""
|
| 112 |
+
@property
|
| 113 |
+
def type_decl(self):
|
| 114 |
+
"""
|
| 115 |
+
Return the string defining the type
|
| 116 |
+
"""
|
| 117 |
+
if self._type_decl is not None:
|
| 118 |
+
return self._type_decl
|
| 119 |
+
|
| 120 |
+
self._type_decl = self.decl_descriptor()
|
| 121 |
+
self._type_decl += f"""
|
| 122 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
|
| 123 |
+
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
| 124 |
+
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
|
| 125 |
+
typename {self.descriptor}::CopyOpR2S
|
| 126 |
+
>;
|
| 127 |
+
"""
|
| 128 |
+
return self._type_decl
|
| 129 |
+
|
| 130 |
+
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
|
| 131 |
+
"""
|
| 132 |
+
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
| 133 |
+
"""
|
| 134 |
+
return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_emitter.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Emitter for Sm80 Epilogue Visitor
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
| 38 |
+
from cutlass_cppgen.backend import GemmOperationUniversal
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class Sm80Emitter:
|
| 42 |
+
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
| 43 |
+
self.fusion_callbacks = FusionCallbacks(graph, cc=80)
|
| 44 |
+
|
| 45 |
+
def emit(self):
|
| 46 |
+
callback_decl, callback_name = self.fusion_callbacks.emit()
|
| 47 |
+
return callback_name, callback_decl
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm80_nodes.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_library import DataTypeSize, DataTypeTag
|
| 34 |
+
|
| 35 |
+
from cutlass_cppgen.backend.evt.ir import (
|
| 36 |
+
# Load Node
|
| 37 |
+
AccumulatorImpl,
|
| 38 |
+
AuxLoadImpl,
|
| 39 |
+
ColumnBroadcastImpl,
|
| 40 |
+
LoadNode,
|
| 41 |
+
LoadSrcImpl,
|
| 42 |
+
RowBroadcastImpl,
|
| 43 |
+
ScalarBroadcastImpl,
|
| 44 |
+
# Compute Node
|
| 45 |
+
ComputeImpl,
|
| 46 |
+
# Store Node
|
| 47 |
+
AuxStoreImpl,
|
| 48 |
+
ColumnReductionImpl,
|
| 49 |
+
RowReductionImpl,
|
| 50 |
+
ScalarReductionImpl
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
from cutlass_cppgen.backend.library import (
|
| 54 |
+
FloatRoundStyleTag,
|
| 55 |
+
FunctionalOp,
|
| 56 |
+
op_tag,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Sm80AccumulatorImpl(AccumulatorImpl):
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def type_decl(self):
|
| 64 |
+
"""
|
| 65 |
+
Return the string defining the type
|
| 66 |
+
"""
|
| 67 |
+
if self._type_decl is not None:
|
| 68 |
+
return self._type_decl
|
| 69 |
+
|
| 70 |
+
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::threadblock::VisitorAccFetch;\n"""
|
| 71 |
+
return self._type_decl
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Sm80AuxLoadImpl(AuxLoadImpl):
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def type_decl(self):
|
| 78 |
+
"""
|
| 79 |
+
Return the string defining the type
|
| 80 |
+
"""
|
| 81 |
+
if self._type_decl is not None:
|
| 82 |
+
return self._type_decl
|
| 83 |
+
|
| 84 |
+
self._type_decl = f"""
|
| 85 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxLoad<
|
| 86 |
+
OutputTileThreadMap, {DataTypeTag[self.element]}, {self.stride_mnl}
|
| 87 |
+
>;
|
| 88 |
+
"""
|
| 89 |
+
return self._type_decl
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class Sm80LoadSrcImpl(Sm80AuxLoadImpl):
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class Sm80ScalarBroadcastImpl(ScalarBroadcastImpl):
|
| 97 |
+
def __init__(self, node: LoadNode) -> None:
|
| 98 |
+
super().__init__(node)
|
| 99 |
+
self.broadcast_count = 1
|
| 100 |
+
self.reduction_fn = FunctionalOp.Multiplies
|
| 101 |
+
|
| 102 |
+
@property
|
| 103 |
+
def type_decl(self):
|
| 104 |
+
"""
|
| 105 |
+
Return the string defining the type
|
| 106 |
+
"""
|
| 107 |
+
if self._type_decl is not None:
|
| 108 |
+
return self._type_decl
|
| 109 |
+
|
| 110 |
+
self._type_decl = f"""
|
| 111 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarBroadcast<
|
| 112 |
+
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
|
| 113 |
+
>;
|
| 114 |
+
"""
|
| 115 |
+
return self._type_decl
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class Sm80RowBroadcastImpl(RowBroadcastImpl):
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def type_decl(self):
|
| 122 |
+
"""
|
| 123 |
+
Return the string defining the type
|
| 124 |
+
"""
|
| 125 |
+
if self._type_decl is not None:
|
| 126 |
+
return self._type_decl
|
| 127 |
+
|
| 128 |
+
self._type_decl = f"""
|
| 129 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
| 130 |
+
OutputTileThreadMap, {DataTypeTag[self.element]},
|
| 131 |
+
{self.stride_mnl}
|
| 132 |
+
>;
|
| 133 |
+
"""
|
| 134 |
+
return self._type_decl
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Sm80ColumnBroadcastImpl(ColumnBroadcastImpl):
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def type_decl(self):
|
| 141 |
+
"""
|
| 142 |
+
Return the string defining the type
|
| 143 |
+
"""
|
| 144 |
+
if self._type_decl is not None:
|
| 145 |
+
return self._type_decl
|
| 146 |
+
|
| 147 |
+
self._type_decl = f"""
|
| 148 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColBroadcast<
|
| 149 |
+
OutputTileThreadMap, {DataTypeTag[self.element]},
|
| 150 |
+
{self.stride_mnl}
|
| 151 |
+
>;
|
| 152 |
+
"""
|
| 153 |
+
return self._type_decl
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Sm80ComputeImpl(ComputeImpl):
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
def type_decl(self):
|
| 160 |
+
"""
|
| 161 |
+
Return the string defining the type
|
| 162 |
+
"""
|
| 163 |
+
if self._type_decl is not None:
|
| 164 |
+
return self._type_decl
|
| 165 |
+
|
| 166 |
+
self._type_decl = f"""
|
| 167 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorCompute<
|
| 168 |
+
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
|
| 169 |
+
{FloatRoundStyleTag[self.round_style]}
|
| 170 |
+
>;
|
| 171 |
+
"""
|
| 172 |
+
return self._type_decl
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class Sm80AuxStoreImpl(AuxStoreImpl):
|
| 176 |
+
|
| 177 |
+
@property
|
| 178 |
+
def type_decl(self):
|
| 179 |
+
"""
|
| 180 |
+
Return the string defining the type
|
| 181 |
+
"""
|
| 182 |
+
if self._type_decl is not None:
|
| 183 |
+
return self._type_decl
|
| 184 |
+
|
| 185 |
+
self._type_decl = f"""
|
| 186 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorAuxStore<
|
| 187 |
+
OutputTileThreadMap, {DataTypeTag[self.element]}, {FloatRoundStyleTag[self.round_style]},
|
| 188 |
+
{self.stride_mnl}
|
| 189 |
+
>;
|
| 190 |
+
"""
|
| 191 |
+
return self._type_decl
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class Sm80StoreDImpl(Sm80AuxStoreImpl):
|
| 195 |
+
pass
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Sm80ColumnReductionImpl(ColumnReductionImpl):
|
| 199 |
+
|
| 200 |
+
@property
|
| 201 |
+
def type_decl(self):
|
| 202 |
+
"""
|
| 203 |
+
Return the string defining the type
|
| 204 |
+
"""
|
| 205 |
+
if self._type_decl is not None:
|
| 206 |
+
return self._type_decl
|
| 207 |
+
|
| 208 |
+
self._type_decl = f"""
|
| 209 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorColReduction<
|
| 210 |
+
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
| 211 |
+
OutputTileThreadMap, {DataTypeTag[self.element]},
|
| 212 |
+
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
| 213 |
+
{self.stride_mnl}
|
| 214 |
+
>;
|
| 215 |
+
"""
|
| 216 |
+
return self._type_decl
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class Sm80RowReductionImpl(RowReductionImpl):
|
| 220 |
+
|
| 221 |
+
@property
|
| 222 |
+
def type_decl(self):
|
| 223 |
+
"""
|
| 224 |
+
Return the string defining the type
|
| 225 |
+
"""
|
| 226 |
+
if self._type_decl is not None:
|
| 227 |
+
return self._type_decl
|
| 228 |
+
|
| 229 |
+
self._type_decl = f"""
|
| 230 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorRowReduction<
|
| 231 |
+
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
| 232 |
+
OutputTileThreadMap, {DataTypeTag[self.element]},
|
| 233 |
+
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
| 234 |
+
{self.stride_mnl}
|
| 235 |
+
>;
|
| 236 |
+
"""
|
| 237 |
+
return self._type_decl
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
class Sm80ScalarReductionImpl(ScalarReductionImpl):
|
| 241 |
+
|
| 242 |
+
@property
|
| 243 |
+
def type_decl(self):
|
| 244 |
+
"""
|
| 245 |
+
Return the string defining the type
|
| 246 |
+
"""
|
| 247 |
+
if self._type_decl is not None:
|
| 248 |
+
return self._type_decl
|
| 249 |
+
|
| 250 |
+
self._type_decl = f"""
|
| 251 |
+
using {self.name_camel} = cutlass::epilogue::threadblock::VisitorScalarReduction<
|
| 252 |
+
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
| 253 |
+
OutputTileThreadMap, {DataTypeTag[self.element]},
|
| 254 |
+
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
| 255 |
+
{self.stride_mnl}
|
| 256 |
+
>;
|
| 257 |
+
"""
|
| 258 |
+
return self._type_decl
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_emitter.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Emitter for Sm90 Epilogue Visitor
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_library import DataTypeTag, EpilogueScheduleTag
|
| 38 |
+
from cutlass_cppgen.backend import GemmOperationUniversal
|
| 39 |
+
from cutlass_cppgen.backend.evt.backend.emitter_base import FusionCallbacks
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class CollectiveEpilogue:
|
| 43 |
+
def __init__(self, tile_description,
|
| 44 |
+
schedule,
|
| 45 |
+
element_c,
|
| 46 |
+
element_d,
|
| 47 |
+
fusion_callbacks) -> None:
|
| 48 |
+
|
| 49 |
+
self.cta_tile_mnk = tile_description.threadblock_shape
|
| 50 |
+
self.element_c = element_c
|
| 51 |
+
self.element_d = element_d
|
| 52 |
+
self.schedule = schedule
|
| 53 |
+
self.fusion_callbacks = fusion_callbacks
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def CtaTileMNK(self) -> str:
|
| 57 |
+
"""
|
| 58 |
+
The threadblock shape
|
| 59 |
+
"""
|
| 60 |
+
return f"cute::Shape<_{self.cta_tile_mnk[0]}, _{self.cta_tile_mnk[1]}, _{self.cta_tile_mnk[2]}>"
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def EpilogueTileType(self) -> str:
|
| 64 |
+
"""
|
| 65 |
+
The epilogue tile type
|
| 66 |
+
"""
|
| 67 |
+
return "cutlass::epilogue::collective::EpilogueTileAuto"
|
| 68 |
+
|
| 69 |
+
@property
|
| 70 |
+
def Schedule(self) -> str:
|
| 71 |
+
return EpilogueScheduleTag[self.schedule]
|
| 72 |
+
|
| 73 |
+
def emit(self):
|
| 74 |
+
callback_decl, callback_name = self.fusion_callbacks.emit()
|
| 75 |
+
return callback_name, f"""
|
| 76 |
+
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
| 77 |
+
{self.CtaTileMNK}, {self.EpilogueTileType},
|
| 78 |
+
{DataTypeTag[self.element_c]}, {DataTypeTag[self.element_d]},
|
| 79 |
+
{self.Schedule}
|
| 80 |
+
>;
|
| 81 |
+
{callback_decl}
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Sm90Emitter:
|
| 86 |
+
def __init__(self, operation: GemmOperationUniversal, graph) -> None:
|
| 87 |
+
fusion_callbacks = FusionCallbacks(graph, cc=90, emit_CD=False)
|
| 88 |
+
|
| 89 |
+
self.collective_epilogue = CollectiveEpilogue(
|
| 90 |
+
tile_description=operation.tile_description,
|
| 91 |
+
schedule=operation.tile_description.epilogue_schedule,
|
| 92 |
+
element_c=operation.C.element,
|
| 93 |
+
element_d=operation.C.element,
|
| 94 |
+
fusion_callbacks=fusion_callbacks
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
def emit(self):
|
| 98 |
+
return self.collective_epilogue.emit()
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/backend/sm90_nodes.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from pycute import product
|
| 34 |
+
|
| 35 |
+
from cutlass_library import DataTypeSize, DataTypeTag
|
| 36 |
+
from cutlass_cppgen.backend.evt.ir import (
|
| 37 |
+
# Load Node
|
| 38 |
+
AccumulatorImpl,
|
| 39 |
+
AuxLoadImpl,
|
| 40 |
+
ColumnBroadcastImpl,
|
| 41 |
+
LoadNode,
|
| 42 |
+
LoadSrcImpl,
|
| 43 |
+
RowBroadcastImpl,
|
| 44 |
+
ScalarBroadcastImpl,
|
| 45 |
+
# Compute Node
|
| 46 |
+
ComputeImpl,
|
| 47 |
+
ComputeNode,
|
| 48 |
+
# Store Node
|
| 49 |
+
AuxStoreImpl,
|
| 50 |
+
ColumnReductionImpl,
|
| 51 |
+
RowReductionImpl,
|
| 52 |
+
ScalarReductionImpl,
|
| 53 |
+
StoreNode,
|
| 54 |
+
StoreDImpl,
|
| 55 |
+
)
|
| 56 |
+
from cutlass_cppgen.backend.library import (
|
| 57 |
+
FloatRoundStyleTag,
|
| 58 |
+
FunctionalOp,
|
| 59 |
+
op_tag,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class Sm90AccumulatorImpl(AccumulatorImpl):
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def type_decl(self):
|
| 67 |
+
"""
|
| 68 |
+
Return the string defining the type
|
| 69 |
+
"""
|
| 70 |
+
if self._type_decl is not None:
|
| 71 |
+
return self._type_decl
|
| 72 |
+
|
| 73 |
+
self._type_decl = f"""\nusing {self.name_camel} = cutlass::epilogue::fusion::Sm90AccFetch;\n"""
|
| 74 |
+
return self._type_decl
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Sm90LoadSrcImpl(LoadSrcImpl):
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def type_decl(self):
|
| 81 |
+
"""
|
| 82 |
+
Return the string defining the type
|
| 83 |
+
"""
|
| 84 |
+
if self._type_decl is not None:
|
| 85 |
+
return self._type_decl
|
| 86 |
+
|
| 87 |
+
self._type_decl = f"""
|
| 88 |
+
using ElementC = {DataTypeTag[self.element]};
|
| 89 |
+
using StrideC = {self.stride_mnl};
|
| 90 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90SrcFetch<{DataTypeTag[self.element]}>;
|
| 91 |
+
"""
|
| 92 |
+
return self._type_decl
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class Sm90AuxLoadImpl(AuxLoadImpl):
|
| 96 |
+
|
| 97 |
+
@property
|
| 98 |
+
def descriptor(self) -> str:
|
| 99 |
+
"""
|
| 100 |
+
Descriptor for Aux Load
|
| 101 |
+
"""
|
| 102 |
+
return f"{self.name_camel}Descriptor"
|
| 103 |
+
|
| 104 |
+
def decl_descriptor(self) -> str:
|
| 105 |
+
"""
|
| 106 |
+
Declare the descriptor type
|
| 107 |
+
"""
|
| 108 |
+
return f"\nusing {self.descriptor} = cutlass::epilogue::collective::detail::AuxLoadDescriptor<EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}>;\n"
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def type_decl(self):
|
| 112 |
+
"""
|
| 113 |
+
Return the string defining the type
|
| 114 |
+
"""
|
| 115 |
+
if self._type_decl is not None:
|
| 116 |
+
return self._type_decl
|
| 117 |
+
|
| 118 |
+
self._type_decl = self.decl_descriptor()
|
| 119 |
+
self._type_decl += f"""
|
| 120 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxLoad<
|
| 121 |
+
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
| 122 |
+
{self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom, typename {self.descriptor}::CopyOpS2R
|
| 123 |
+
>;
|
| 124 |
+
"""
|
| 125 |
+
return self._type_decl
|
| 126 |
+
|
| 127 |
+
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
|
| 128 |
+
"""
|
| 129 |
+
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
| 130 |
+
"""
|
| 131 |
+
return (DataTypeSize[self.element] * stages_c * product(epilogue_tile_mn) // 8, 128)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class Sm90ScalarBroadcastImpl(ScalarBroadcastImpl):
|
| 135 |
+
def __init__(self, node: LoadNode) -> None:
|
| 136 |
+
super().__init__(node)
|
| 137 |
+
self.broadcast_count = 1
|
| 138 |
+
self.reduction_fn = FunctionalOp.Multiplies
|
| 139 |
+
|
| 140 |
+
@property
|
| 141 |
+
def type_decl(self):
|
| 142 |
+
"""
|
| 143 |
+
Return the string defining the type
|
| 144 |
+
"""
|
| 145 |
+
if self._type_decl is not None:
|
| 146 |
+
return self._type_decl
|
| 147 |
+
|
| 148 |
+
self._type_decl = f"""
|
| 149 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarBroadcast<
|
| 150 |
+
{DataTypeTag[self.element]}, {self.stride_mnl}, {self.broadcast_count}, {op_tag(self.reduction_fn)}
|
| 151 |
+
>;
|
| 152 |
+
"""
|
| 153 |
+
return self._type_decl
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Sm90RowBroadcastImpl(RowBroadcastImpl):
|
| 157 |
+
@property
|
| 158 |
+
def type_decl(self):
|
| 159 |
+
"""
|
| 160 |
+
Return the string defining the type
|
| 161 |
+
"""
|
| 162 |
+
if self._type_decl is not None:
|
| 163 |
+
return self._type_decl
|
| 164 |
+
|
| 165 |
+
self._type_decl = f"""
|
| 166 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
| 167 |
+
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
| 168 |
+
{self.stride_mnl}
|
| 169 |
+
>;
|
| 170 |
+
"""
|
| 171 |
+
return self._type_decl
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class Sm90ColumnBroadcastImpl(ColumnBroadcastImpl):
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def type_decl(self):
|
| 178 |
+
"""
|
| 179 |
+
Return the string defining the type
|
| 180 |
+
"""
|
| 181 |
+
if self._type_decl is not None:
|
| 182 |
+
return self._type_decl
|
| 183 |
+
|
| 184 |
+
self._type_decl = f"""
|
| 185 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColBroadcast<
|
| 186 |
+
0 /*Stages*/, typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]}, {DataTypeTag[self.element_output]},
|
| 187 |
+
{self.stride_mnl}
|
| 188 |
+
>;
|
| 189 |
+
"""
|
| 190 |
+
return self._type_decl
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
class Sm90ComputeImpl(ComputeImpl):
|
| 194 |
+
|
| 195 |
+
@property
|
| 196 |
+
def type_decl(self):
|
| 197 |
+
"""
|
| 198 |
+
Return the string defining the type
|
| 199 |
+
"""
|
| 200 |
+
if self._type_decl is not None:
|
| 201 |
+
return self._type_decl
|
| 202 |
+
|
| 203 |
+
self._type_decl = f"""
|
| 204 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90Compute<
|
| 205 |
+
{op_tag(self.fn)}, {DataTypeTag[self.element_output]}, {DataTypeTag[self.element_compute]},
|
| 206 |
+
{FloatRoundStyleTag[self.round_style]}
|
| 207 |
+
>;
|
| 208 |
+
"""
|
| 209 |
+
return self._type_decl
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class Sm90AuxStoreImpl(AuxStoreImpl):
|
| 213 |
+
|
| 214 |
+
@property
|
| 215 |
+
def descriptor(self) -> str:
|
| 216 |
+
"""
|
| 217 |
+
Descriptor for Aux Load
|
| 218 |
+
"""
|
| 219 |
+
return f"{self.name_camel}Descriptor"
|
| 220 |
+
|
| 221 |
+
def decl_descriptor(self) -> str:
|
| 222 |
+
"""
|
| 223 |
+
Declare the descriptor type
|
| 224 |
+
"""
|
| 225 |
+
return f"""
|
| 226 |
+
using {self.descriptor} = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
| 227 |
+
EpilogueDescriptor, {self.stride_mnl}, {DataTypeTag[self.element]}
|
| 228 |
+
>;
|
| 229 |
+
"""
|
| 230 |
+
@property
|
| 231 |
+
def type_decl(self):
|
| 232 |
+
"""
|
| 233 |
+
Return the string defining the type
|
| 234 |
+
"""
|
| 235 |
+
if self._type_decl is not None:
|
| 236 |
+
return self._type_decl
|
| 237 |
+
|
| 238 |
+
self._type_decl = self.decl_descriptor()
|
| 239 |
+
self._type_decl += f"""
|
| 240 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90AuxStore<
|
| 241 |
+
{self.descriptor}::Stages, typename {self.descriptor}::EpilogueTile, {DataTypeTag[self.element]},
|
| 242 |
+
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}, typename {self.descriptor}::SmemLayoutAtom,
|
| 243 |
+
typename {self.descriptor}::CopyOpR2S
|
| 244 |
+
>;
|
| 245 |
+
"""
|
| 246 |
+
return self._type_decl
|
| 247 |
+
|
| 248 |
+
def get_smem_size(self, cta_tile_mnk, epilogue_tile_mn, stages_c, stages_d, epi_tiles):
|
| 249 |
+
"""
|
| 250 |
+
Get the shared memory size based on epilogue_tile_mn, stages_c, and stages_d
|
| 251 |
+
"""
|
| 252 |
+
return (DataTypeSize[self.element] * stages_d * product(epilogue_tile_mn) // 8, 128)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class Sm90StoreDImpl(StoreDImpl):
|
| 256 |
+
|
| 257 |
+
@property
|
| 258 |
+
def type_decl(self):
|
| 259 |
+
"""
|
| 260 |
+
Return the string defining the type
|
| 261 |
+
"""
|
| 262 |
+
return f"""
|
| 263 |
+
using ElementD = {DataTypeTag[self.element]};
|
| 264 |
+
using StrideD = {self.stride_mnl};
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class Sm90ColumnReductionImpl(ColumnReductionImpl):
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def type_decl(self):
|
| 272 |
+
"""
|
| 273 |
+
Return the string defining the type
|
| 274 |
+
"""
|
| 275 |
+
if self._type_decl is not None:
|
| 276 |
+
return self._type_decl
|
| 277 |
+
|
| 278 |
+
self._type_decl = f"""
|
| 279 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ColReduction<
|
| 280 |
+
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0,
|
| 281 |
+
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
| 282 |
+
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
| 283 |
+
{self.stride_mnl}
|
| 284 |
+
>;
|
| 285 |
+
"""
|
| 286 |
+
return self._type_decl
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
class Sm90RowReductionImpl(RowReductionImpl):
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
@property
|
| 293 |
+
def type_decl(self):
|
| 294 |
+
"""
|
| 295 |
+
Return the string defining the type
|
| 296 |
+
"""
|
| 297 |
+
if self._type_decl is not None:
|
| 298 |
+
return self._type_decl
|
| 299 |
+
|
| 300 |
+
self._type_decl = f"""
|
| 301 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90RowReduction<
|
| 302 |
+
{op_tag(self.reg_reduce_fn)}, {op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)}, 0 /* Stages */,
|
| 303 |
+
typename EpilogueDescriptor::TileShape, {DataTypeTag[self.element]},
|
| 304 |
+
{DataTypeTag[self.element_compute]}, {FloatRoundStyleTag[self.round_style]},
|
| 305 |
+
{self.stride_mnl}
|
| 306 |
+
>;
|
| 307 |
+
"""
|
| 308 |
+
return self._type_decl
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class Sm90ScalarReductionImpl(ScalarReductionImpl):
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
@property
|
| 315 |
+
def type_decl(self):
|
| 316 |
+
"""
|
| 317 |
+
Return the string defining the type
|
| 318 |
+
"""
|
| 319 |
+
if self._type_decl is not None:
|
| 320 |
+
return self._type_decl
|
| 321 |
+
|
| 322 |
+
self._type_decl = f"""
|
| 323 |
+
using {self.name_camel} = cutlass::epilogue::fusion::Sm90ScalarReduction<
|
| 324 |
+
{op_tag(self.reg_reduce_fn)}, {op_tag(self.gmem_reduce_fn)},
|
| 325 |
+
{DataTypeTag[self.element]}, {DataTypeTag[self.element_compute]},
|
| 326 |
+
{FloatRoundStyleTag[self.round_style]}, {self.stride_mnl}
|
| 327 |
+
>;
|
| 328 |
+
"""
|
| 329 |
+
return self._type_decl
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/epilogue.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Epilogue Visitor interface for compiling, and running visitor-based epilogue.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import ctypes
|
| 38 |
+
|
| 39 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 40 |
+
cuda = lazy_import("cuda.cuda")
|
| 41 |
+
from cutlass_library import DataType
|
| 42 |
+
import numpy as np
|
| 43 |
+
|
| 44 |
+
from cutlass_cppgen.backend.epilogue import EpilogueFunctorBase
|
| 45 |
+
import cutlass_cppgen.backend.evt.backend
|
| 46 |
+
from cutlass_cppgen.backend.frontend import TensorFrontend
|
| 47 |
+
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
| 48 |
+
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class EpilogueFunctorVisitor(EpilogueFunctorBase):
|
| 52 |
+
"""
|
| 53 |
+
Apply an epilogue functor described by the epilogue EVT
|
| 54 |
+
|
| 55 |
+
:param cc: compute capability
|
| 56 |
+
:param visitor_frontend: user-provide visitor frontend
|
| 57 |
+
|
| 58 |
+
"""
|
| 59 |
+
def __init__(self, cc: int, visitor, element_compute=DataType.f32) -> None:
|
| 60 |
+
# Type of Emitter based on CC
|
| 61 |
+
self.emit_cls = getattr(cutlass_cppgen.backend.evt.backend, f"Sm{cc_map[cc]}Emitter")
|
| 62 |
+
|
| 63 |
+
# Visitor Types
|
| 64 |
+
self.visitor = visitor
|
| 65 |
+
self.graph = visitor.dag_ir
|
| 66 |
+
|
| 67 |
+
# Data types
|
| 68 |
+
self.element_epilogue = element_compute # element compute
|
| 69 |
+
self.element_output = self.graph.get_node_meta('D').underlying_impl.element
|
| 70 |
+
|
| 71 |
+
# Epilogue Thread Type
|
| 72 |
+
epilogue_thread_type = self.visitor.epilogue_thread_type
|
| 73 |
+
if cc_map[cc] in [90, 100]:
|
| 74 |
+
self.arg_c_type = self.visitor.arg_c_type
|
| 75 |
+
self.arg_d_type = self.visitor.arg_d_type
|
| 76 |
+
output_names = self.visitor.return_names
|
| 77 |
+
reduction_names = self.visitor.reduction_names
|
| 78 |
+
|
| 79 |
+
# Epilogue stages specialized for sm80 kernel
|
| 80 |
+
if cc == 80:
|
| 81 |
+
if hasattr(self.visitor, "epilogue_stages"):
|
| 82 |
+
self.epilogue_stages = self.visitor.epilogue_stages
|
| 83 |
+
assert self.epilogue_stages <= 2, "Only supports Stages <=2 in SM80 Epilogue"
|
| 84 |
+
|
| 85 |
+
# Epilogue Argument Type
|
| 86 |
+
class _Arguments(ctypes.Structure):
|
| 87 |
+
"""
|
| 88 |
+
Concepts:
|
| 89 |
+
class _EpilogueArguments(ctypes.Structure):
|
| 90 |
+
_fields_ = [
|
| 91 |
+
("epilogue", _Arguments), <- this class
|
| 92 |
+
("ptr_C", ctypes.c_void_p),
|
| 93 |
+
("stride_C", StrideBatched_),
|
| 94 |
+
("ptr_D", ctypes.c_void_p),
|
| 95 |
+
("stride_D", StrideBatched_)
|
| 96 |
+
]
|
| 97 |
+
"""
|
| 98 |
+
_fields_ = [
|
| 99 |
+
("output_op", epilogue_thread_type)
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
def __init__(self, kwargs: dict) -> None:
|
| 103 |
+
# The user-input kwargs is a dict of (name: tensors)
|
| 104 |
+
# We first convert all of them to device pointers
|
| 105 |
+
ptr_kwargs = {}
|
| 106 |
+
for key in kwargs.keys():
|
| 107 |
+
is_output = key in output_names and key not in reduction_names
|
| 108 |
+
ptr_kwargs[key] = self.get_tensor_ptr(key, kwargs, is_output)
|
| 109 |
+
# Initialize the thread arguments
|
| 110 |
+
self.output_op = epilogue_thread_type(ptr_kwargs)
|
| 111 |
+
|
| 112 |
+
def get_tensor_ptr(self, tensor_name, kwargs, is_output=False):
|
| 113 |
+
"""
|
| 114 |
+
Helper function for extracting device pointer
|
| 115 |
+
"""
|
| 116 |
+
# Skip the special tensors
|
| 117 |
+
if cc in [90, 100]:
|
| 118 |
+
if tensor_name in ["C", "D"]:
|
| 119 |
+
return 0
|
| 120 |
+
if tensor_name not in kwargs.keys():
|
| 121 |
+
raise ValueError(f"Tensor {tensor_name} is not provided.")
|
| 122 |
+
tensor = kwargs[tensor_name]
|
| 123 |
+
|
| 124 |
+
# For float scalar constant, directly return the value
|
| 125 |
+
if isinstance(tensor, float):
|
| 126 |
+
return tensor
|
| 127 |
+
|
| 128 |
+
# The tensor frontend returns a device buffer for np.ndarray
|
| 129 |
+
# and device ptr for other frontends
|
| 130 |
+
buffer_or_ptr = TensorFrontend.argument(tensor, is_output)
|
| 131 |
+
if is_numpy_tensor(tensor):
|
| 132 |
+
# Remember the host tensor for later synchronization
|
| 133 |
+
setattr(self, f"{tensor_name}_buffer", buffer_or_ptr)
|
| 134 |
+
setattr(self, f"{tensor_name}_host", tensor)
|
| 135 |
+
return int(buffer_or_ptr.ptr)
|
| 136 |
+
else:
|
| 137 |
+
return int(buffer_or_ptr)
|
| 138 |
+
|
| 139 |
+
def sync(self):
|
| 140 |
+
"""
|
| 141 |
+
Synchronize the results from device to host
|
| 142 |
+
"""
|
| 143 |
+
for name in output_names:
|
| 144 |
+
if hasattr(self, f"{name}_host"):
|
| 145 |
+
host_tensor = getattr(self, f"{name}_host")
|
| 146 |
+
tensor_ptr = getattr(self, f"{name}_buffer").ptr
|
| 147 |
+
(err,) = cuda.cuMemcpyDtoH(
|
| 148 |
+
host_tensor,
|
| 149 |
+
tensor_ptr,
|
| 150 |
+
host_tensor.size * host_tensor.itemsize,
|
| 151 |
+
)
|
| 152 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 153 |
+
raise RuntimeError("CUDA Error %s" % str(err))
|
| 154 |
+
|
| 155 |
+
self.epilogue_type = _Arguments
|
| 156 |
+
|
| 157 |
+
def emit(self, operation):
|
| 158 |
+
"""
|
| 159 |
+
Emit the C++ code
|
| 160 |
+
"""
|
| 161 |
+
emitter = self.emit_cls(operation, self.graph)
|
| 162 |
+
return emitter.emit()
|
| 163 |
+
|
| 164 |
+
def get_smem_size(self, tile_description):
|
| 165 |
+
"""
|
| 166 |
+
Get the shared memory size in bytes
|
| 167 |
+
"""
|
| 168 |
+
return self.visitor.get_smem_size(tile_description)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.backend.evt.frontend.python_ast import PythonASTFrontend
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/frontend_base.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Base class for Python EVT Frontend
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from typing import Union
|
| 38 |
+
|
| 39 |
+
from cutlass_library import DataType
|
| 40 |
+
from cutlass_cppgen.backend.evt.ir import (
|
| 41 |
+
ComputeNode,
|
| 42 |
+
DAGIR,
|
| 43 |
+
LayoutNode,
|
| 44 |
+
LoadNode,
|
| 45 |
+
StoreNode,
|
| 46 |
+
)
|
| 47 |
+
from cutlass_cppgen.backend.evt.passes import (
|
| 48 |
+
EVTGraphDrawer,
|
| 49 |
+
EVTPassManager,
|
| 50 |
+
GetSmemSize,
|
| 51 |
+
PassDAG2Tree,
|
| 52 |
+
PassGetArgumentType,
|
| 53 |
+
PassGetImpl,
|
| 54 |
+
PassFixElementD,
|
| 55 |
+
PassLayoutManipulateElimination,
|
| 56 |
+
PassPreprocessRed,
|
| 57 |
+
PassShapeTypePropagation,
|
| 58 |
+
)
|
| 59 |
+
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 60 |
+
from cutlass_cppgen.backend.utils import device_cc
|
| 61 |
+
from cutlass_cppgen.epilogue.evt_ops import permute, reshape
|
| 62 |
+
from cutlass_cppgen.utils.datatypes import library_type
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class EVTFrontendBase:
|
| 66 |
+
layout_fns = {
|
| 67 |
+
"permute": permute,
|
| 68 |
+
"reshape": reshape
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
def __init__(self, cc, element_compute=DataType.f32, additional_passes=[], **kwargs) -> None:
|
| 72 |
+
self.cc = cc
|
| 73 |
+
self.element_compute = library_type(element_compute)
|
| 74 |
+
self.dag_ir = DAGIR(self.cc, self.element_compute)
|
| 75 |
+
self.compute_cnt = 0
|
| 76 |
+
self.layout_cnt = 0
|
| 77 |
+
self.imm_cnt = 0
|
| 78 |
+
|
| 79 |
+
self.pass_manager = EVTPassManager(
|
| 80 |
+
self.dag_ir,
|
| 81 |
+
[
|
| 82 |
+
PassPreprocessRed,
|
| 83 |
+
PassGetArgumentType,
|
| 84 |
+
PassShapeTypePropagation,
|
| 85 |
+
PassLayoutManipulateElimination,
|
| 86 |
+
PassGetImpl,
|
| 87 |
+
PassDAG2Tree,
|
| 88 |
+
PassFixElementD
|
| 89 |
+
] + additional_passes)
|
| 90 |
+
|
| 91 |
+
if self.cc == 80:
|
| 92 |
+
self._epilogue_stages = 1
|
| 93 |
+
else:
|
| 94 |
+
self._epilogue_stages = None
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def epilogue_stages(self):
|
| 98 |
+
return self._epilogue_stages
|
| 99 |
+
|
| 100 |
+
@epilogue_stages.setter
|
| 101 |
+
def epilogue_stages(self, stages):
|
| 102 |
+
self._epilogue_stages = stages
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def parse(self, *args, **kwargs):
|
| 106 |
+
raise NotImplementedError(f"The 'parse' function must be overloaded in frontend class")
|
| 107 |
+
|
| 108 |
+
def trace(self, *args, **kwargs):
|
| 109 |
+
# Parse the input
|
| 110 |
+
self.parse(*args, **kwargs)
|
| 111 |
+
|
| 112 |
+
# Verify the DAG IR to ensure that "D" is the output node with out_degree = 0
|
| 113 |
+
if (self.cc >= 90):
|
| 114 |
+
if (self.dag_ir.out_degree("D") != 0):
|
| 115 |
+
raise RuntimeError(
|
| 116 |
+
f"On SM90 or higher, D is expected to be a output node with 0 users to "
|
| 117 |
+
f"enable smem reuse between C and D, but got {self.dag_ir.out_degree('D')}")
|
| 118 |
+
|
| 119 |
+
# Run the passes
|
| 120 |
+
self.pass_manager()
|
| 121 |
+
# Set the epilogue type
|
| 122 |
+
self.epilogue_thread_type = self.dag_ir.epilogue_thread_type
|
| 123 |
+
if cc_map[self.cc] in [90, 100]:
|
| 124 |
+
self.arg_c_type = self.dag_ir.arg_c_type
|
| 125 |
+
self.arg_d_type = self.dag_ir.arg_d_type
|
| 126 |
+
self.reduction_names = self.dag_ir.reduction_names
|
| 127 |
+
|
| 128 |
+
#
|
| 129 |
+
# Helper functions for DAG IR manipulation
|
| 130 |
+
#
|
| 131 |
+
|
| 132 |
+
def add_node(self, node):
|
| 133 |
+
self.dag_ir.add_node(node)
|
| 134 |
+
|
| 135 |
+
def add_edge(self, src, tgt, weight=0):
|
| 136 |
+
self.dag_ir.add_edge(src, tgt, weight=weight)
|
| 137 |
+
|
| 138 |
+
def set_tensor(self, node_name, example):
|
| 139 |
+
"""
|
| 140 |
+
Add an example tensor to node {node_name} in the DAG IR
|
| 141 |
+
"""
|
| 142 |
+
meta = self.dag_ir.get_node_meta(node_name)
|
| 143 |
+
meta.tensor = {"tensor": example}
|
| 144 |
+
|
| 145 |
+
def set_store_tensor(self, node_name, example):
|
| 146 |
+
"""
|
| 147 |
+
Add an example tensor to node {node_name} in the DAG IR
|
| 148 |
+
"""
|
| 149 |
+
meta = self.dag_ir.get_node_meta(node_name)
|
| 150 |
+
meta.store_tensor = {"tensor": example}
|
| 151 |
+
|
| 152 |
+
def mark_output(self, node_name):
|
| 153 |
+
"""
|
| 154 |
+
Mark a store node as output
|
| 155 |
+
"""
|
| 156 |
+
meta = self.dag_ir.get_node_meta(node_name)
|
| 157 |
+
if not isinstance(meta, StoreNode):
|
| 158 |
+
raise ValueError(
|
| 159 |
+
f"Only StoreNodes can be marked as output. "
|
| 160 |
+
f"Got {type(meta).__name__}: {node_name}")
|
| 161 |
+
meta.is_output = True
|
| 162 |
+
|
| 163 |
+
# Add node with specific type
|
| 164 |
+
|
| 165 |
+
def add_load_node(self, name, example):
|
| 166 |
+
"""
|
| 167 |
+
Add a Load node to DAG IR
|
| 168 |
+
:param name: name of the loaded variable
|
| 169 |
+
:type name: str
|
| 170 |
+
:param example: example input
|
| 171 |
+
:type example: np.ndarray|torch.Tensor|cupy.ndarray|float
|
| 172 |
+
"""
|
| 173 |
+
if name is None:
|
| 174 |
+
raise ValueError(f"Name is not provided.")
|
| 175 |
+
if example is None:
|
| 176 |
+
raise ValueError(f"Example input for {name} is not provided.")
|
| 177 |
+
load_node = LoadNode(name)
|
| 178 |
+
load_node.tensor = {"tensor": example}
|
| 179 |
+
# Special logics for accumulator
|
| 180 |
+
if name == "accum":
|
| 181 |
+
if load_node.tensor.rank == 2:
|
| 182 |
+
new_shape = tuple([1, ] + list(load_node.tensor.shape))
|
| 183 |
+
load_node.tensor.broadcast(new_shape)
|
| 184 |
+
elif load_node.tensor.rank < 2 or load_node.tensor.rank > 3:
|
| 185 |
+
raise ValueError(f"Expect example inputs for 'accum' be a rank-2 or rank-3 tensor. Got {load_node.tensor.shape}.")
|
| 186 |
+
self.add_node(load_node)
|
| 187 |
+
|
| 188 |
+
def add_imm(self, value: Union[float,int]):
|
| 189 |
+
"""
|
| 190 |
+
Add an immediate scalar value to DAG IR
|
| 191 |
+
:param value: the value of the immediate scalar
|
| 192 |
+
:type value: float
|
| 193 |
+
"""
|
| 194 |
+
try:
|
| 195 |
+
value = float(value)
|
| 196 |
+
except:
|
| 197 |
+
raise ValueError(f"{type(value).__name__} cannot be converted to float.")
|
| 198 |
+
|
| 199 |
+
name = f"imm_{value}_k{self.imm_cnt}".replace('.', '_')
|
| 200 |
+
self.imm_cnt += 1
|
| 201 |
+
load_node = LoadNode(name)
|
| 202 |
+
load_node.tensor = {"tensor": value, "is_constant": True}
|
| 203 |
+
self.add_node(load_node)
|
| 204 |
+
return name
|
| 205 |
+
|
| 206 |
+
def add_compute_node(self, op, name=None):
|
| 207 |
+
"""
|
| 208 |
+
Add a compute node.
|
| 209 |
+
:param op: the computation op
|
| 210 |
+
:param name: the node name (optional)
|
| 211 |
+
:type name: str
|
| 212 |
+
:return: the name of the compute node
|
| 213 |
+
"""
|
| 214 |
+
if name is None:
|
| 215 |
+
name = f"compute_{self.compute_cnt}"
|
| 216 |
+
self.compute_cnt += 1
|
| 217 |
+
compute_node = ComputeNode(
|
| 218 |
+
name=name, fn=op,
|
| 219 |
+
element_output=self.element_compute,
|
| 220 |
+
element_compute=self.element_compute)
|
| 221 |
+
self.add_node(compute_node)
|
| 222 |
+
return compute_node.name
|
| 223 |
+
|
| 224 |
+
def add_layout_node(self, op, kwargs, name=None):
|
| 225 |
+
"""
|
| 226 |
+
Add a layout node.
|
| 227 |
+
:param op: the layout op
|
| 228 |
+
:type op: evt_ops
|
| 229 |
+
:param name: the node name (optional)
|
| 230 |
+
:type name: str
|
| 231 |
+
:return: the name of the layout node
|
| 232 |
+
"""
|
| 233 |
+
if name is None:
|
| 234 |
+
name = f"layout_{self.layout_cnt}"
|
| 235 |
+
self.layout_cnt += 1
|
| 236 |
+
layout_node = LayoutNode(name=name, fn=op, kwargs=kwargs)
|
| 237 |
+
self.add_node(layout_node)
|
| 238 |
+
return layout_node.name
|
| 239 |
+
|
| 240 |
+
def add_store_node(self, name):
|
| 241 |
+
store_node = StoreNode(name)
|
| 242 |
+
self.add_node(store_node)
|
| 243 |
+
|
| 244 |
+
#
|
| 245 |
+
# Visualization The DAG IR
|
| 246 |
+
#
|
| 247 |
+
|
| 248 |
+
def visualize(self, name="dag_ir"):
|
| 249 |
+
"""
|
| 250 |
+
Visualize the dag ir with svg file
|
| 251 |
+
:param name: the name of the graph
|
| 252 |
+
"""
|
| 253 |
+
drawer = EVTGraphDrawer(self.dag_ir, name)
|
| 254 |
+
try:
|
| 255 |
+
for name, graph in drawer.get_dot_graph():
|
| 256 |
+
graph.write_svg(f"./{name}.svg")
|
| 257 |
+
except:
|
| 258 |
+
raise RuntimeError(
|
| 259 |
+
"'dot' is not found in path. GraphDrawer is disabled. "
|
| 260 |
+
"Please install it with 'sudo apt-get install graphviz'."
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
#
|
| 264 |
+
# Get shared memory size
|
| 265 |
+
#
|
| 266 |
+
|
| 267 |
+
def get_smem_size(self, tile_description):
|
| 268 |
+
"""
|
| 269 |
+
Get the shared memory size of the epilogue
|
| 270 |
+
"""
|
| 271 |
+
smem_size = GetSmemSize(self.dag_ir)(tile_description)
|
| 272 |
+
return smem_size
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/frontend/python_ast.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Python AST frontend that parses input into DAG IR
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import ast
|
| 38 |
+
import inspect
|
| 39 |
+
import textwrap
|
| 40 |
+
|
| 41 |
+
from cutlass_library import DataType
|
| 42 |
+
|
| 43 |
+
import cutlass_cppgen
|
| 44 |
+
from cutlass_cppgen.backend.evt.frontend.frontend_base import EVTFrontendBase
|
| 45 |
+
from cutlass_cppgen.backend.epilogue import identity, relu, tanh, sigmoid, silu, hardswish, gelu
|
| 46 |
+
from cutlass_cppgen.backend.library import FunctionalOp
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
| 50 |
+
def __init__(self, cc, element_compute=DataType.f32, **kwargs):
|
| 51 |
+
super().__init__(cc, element_compute, **kwargs)
|
| 52 |
+
# Flags
|
| 53 |
+
# If this state is True, visit_Constant returns values without creating imm node
|
| 54 |
+
self.no_imm = False
|
| 55 |
+
self.visiting_return = False
|
| 56 |
+
|
| 57 |
+
def parse(self, example_inputs):
|
| 58 |
+
self.example_inputs = example_inputs
|
| 59 |
+
self.source = textwrap.dedent(inspect.getsource(self.__call__))
|
| 60 |
+
self.ast = ast.parse(self.source)
|
| 61 |
+
self.visit(self.ast)
|
| 62 |
+
|
| 63 |
+
#
|
| 64 |
+
# Helper functions
|
| 65 |
+
#
|
| 66 |
+
@staticmethod
|
| 67 |
+
def ast_op_to_bindings(op):
|
| 68 |
+
mapping = {
|
| 69 |
+
ast.Add: FunctionalOp.Plus,
|
| 70 |
+
ast.Sub: FunctionalOp.Minus,
|
| 71 |
+
ast.Mult: FunctionalOp.Multiplies,
|
| 72 |
+
ast.Div: FunctionalOp.Divides,
|
| 73 |
+
"maximum": FunctionalOp.Maximum,
|
| 74 |
+
"minimum": FunctionalOp.Minimum,
|
| 75 |
+
"identity": identity.binding_type,
|
| 76 |
+
"relu": relu.binding_type,
|
| 77 |
+
"tanh": tanh.binding_type,
|
| 78 |
+
"sigmoid": sigmoid.binding_type,
|
| 79 |
+
"silu": silu.binding_type,
|
| 80 |
+
"hardswish": hardswish.binding_type,
|
| 81 |
+
"gelu": gelu.binding_type,
|
| 82 |
+
"multiply_add": FunctionalOp.MultiplyAdd,
|
| 83 |
+
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
| 84 |
+
"max": (FunctionalOp.Maximum, FunctionalOp.AtomicMaximum),
|
| 85 |
+
"exp": FunctionalOp.Exp
|
| 86 |
+
}
|
| 87 |
+
return mapping[op]
|
| 88 |
+
|
| 89 |
+
#
|
| 90 |
+
# Visiting different node types
|
| 91 |
+
#
|
| 92 |
+
|
| 93 |
+
def visit_FunctionDef(self, node: ast.FunctionDef):
|
| 94 |
+
# Visit args and register load nodes
|
| 95 |
+
for arg in node.args.args:
|
| 96 |
+
self.visit(arg)
|
| 97 |
+
for expr in node.body:
|
| 98 |
+
self.visit(expr)
|
| 99 |
+
|
| 100 |
+
def visit_arg(self, node: ast.arg):
|
| 101 |
+
# Name of the argument
|
| 102 |
+
name = node.arg
|
| 103 |
+
try:
|
| 104 |
+
example_tensor = self.example_inputs[name]
|
| 105 |
+
except:
|
| 106 |
+
raise RuntimeError(f"Example input for {name} is not provided.")
|
| 107 |
+
|
| 108 |
+
self.add_load_node(name, example_tensor)
|
| 109 |
+
|
| 110 |
+
def visit_Name(self, node: ast.Name):
|
| 111 |
+
return node.id
|
| 112 |
+
|
| 113 |
+
def visit_Constant(self, node: ast.Constant):
|
| 114 |
+
if self.no_imm:
|
| 115 |
+
return node.value
|
| 116 |
+
else:
|
| 117 |
+
name = self.add_imm(node.value)
|
| 118 |
+
return name
|
| 119 |
+
|
| 120 |
+
def visit_Tuple(self, node: ast.Tuple):
|
| 121 |
+
results = []
|
| 122 |
+
for elt in node.elts:
|
| 123 |
+
results.append(self.visit(elt))
|
| 124 |
+
return tuple(results)
|
| 125 |
+
|
| 126 |
+
def visit_keyword(self, node: ast.keyword):
|
| 127 |
+
return {node.arg: self.visit(node.value)}
|
| 128 |
+
|
| 129 |
+
def visit_BinOp(self, node: ast.BinOp):
|
| 130 |
+
if self.visiting_return:
|
| 131 |
+
raise SyntaxError("Return value cannot be an expression")
|
| 132 |
+
lhs = self.visit(node.left)
|
| 133 |
+
rhs = self.visit(node.right)
|
| 134 |
+
op = self.ast_op_to_bindings(type(node.op))
|
| 135 |
+
name = self.add_compute_node(op)
|
| 136 |
+
|
| 137 |
+
# Add edges
|
| 138 |
+
# The edge weights are used to sort the input args
|
| 139 |
+
self.add_edge(lhs, name, weight=0)
|
| 140 |
+
self.add_edge(rhs, name, weight=1)
|
| 141 |
+
return name
|
| 142 |
+
|
| 143 |
+
def visit_Assign(self, node: ast.BinOp):
|
| 144 |
+
target = self.visit(node.targets[0])
|
| 145 |
+
value = self.visit(node.value)
|
| 146 |
+
# Create the assign node
|
| 147 |
+
self.add_store_node(target)
|
| 148 |
+
|
| 149 |
+
# Add edges
|
| 150 |
+
self.add_edge(value, target)
|
| 151 |
+
return target
|
| 152 |
+
|
| 153 |
+
def visit_Call(self, node: ast.Call):
|
| 154 |
+
if self.visiting_return:
|
| 155 |
+
raise SyntaxError("Return value cannot be an expression")
|
| 156 |
+
func = self.visit(node.func)
|
| 157 |
+
args = [self.visit(arg) for arg in node.args]
|
| 158 |
+
|
| 159 |
+
if func in self.layout_fns.keys():
|
| 160 |
+
# Parse kwargs
|
| 161 |
+
# By default, visiting imm automatically creates a load node
|
| 162 |
+
# However, in function call, keyword args are used to set
|
| 163 |
+
# specific function attributes such as indices for permute
|
| 164 |
+
# So no_imm is set to True temporarily
|
| 165 |
+
self.no_imm = True
|
| 166 |
+
kwargs = {}
|
| 167 |
+
for kw in node.keywords:
|
| 168 |
+
kwargs.update(self.visit(kw))
|
| 169 |
+
self.no_imm = False
|
| 170 |
+
op = self.layout_fns[func]
|
| 171 |
+
name = self.add_layout_node(op, kwargs)
|
| 172 |
+
else:
|
| 173 |
+
op = self.ast_op_to_bindings(func)
|
| 174 |
+
name = self.add_compute_node(op)
|
| 175 |
+
|
| 176 |
+
# Add edges
|
| 177 |
+
for idx, arg in enumerate(args):
|
| 178 |
+
self.add_edge(arg, name, weight=idx)
|
| 179 |
+
return name
|
| 180 |
+
|
| 181 |
+
def visit_Return(self, node: ast.Return):
|
| 182 |
+
self.visiting_return = True
|
| 183 |
+
results = self.visit(node.value)
|
| 184 |
+
self.visiting_return = False
|
| 185 |
+
self.return_names = results
|
| 186 |
+
if not isinstance(results, tuple):
|
| 187 |
+
results = (results,)
|
| 188 |
+
for rst in results:
|
| 189 |
+
try:
|
| 190 |
+
example_tensor = self.example_inputs[rst]
|
| 191 |
+
except:
|
| 192 |
+
raise RuntimeError(f"Example input for {rst} is not provided.")
|
| 193 |
+
self.set_store_tensor(rst, example_tensor)
|
| 194 |
+
self.mark_output(rst)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode, ComputeImpl
|
| 34 |
+
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
| 35 |
+
from cutlass_cppgen.backend.evt.ir.layout_nodes import LayoutNode
|
| 36 |
+
from cutlass_cppgen.backend.evt.ir.load_nodes import (
|
| 37 |
+
LoadNode,
|
| 38 |
+
AccumulatorImpl,
|
| 39 |
+
LoadSrcImpl,
|
| 40 |
+
AuxLoadImpl,
|
| 41 |
+
RowBroadcastImpl,
|
| 42 |
+
ColumnBroadcastImpl,
|
| 43 |
+
ScalarBroadcastImpl
|
| 44 |
+
)
|
| 45 |
+
from cutlass_cppgen.backend.evt.ir.node import TopoVisitorNode, NoOpImpl
|
| 46 |
+
from cutlass_cppgen.backend.evt.ir.store_nodes import (
|
| 47 |
+
StoreNode,
|
| 48 |
+
StoreDImpl,
|
| 49 |
+
AuxStoreImpl,
|
| 50 |
+
ColumnReductionImpl,
|
| 51 |
+
RowReductionImpl,
|
| 52 |
+
ScalarReductionImpl
|
| 53 |
+
)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/compute_nodes.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Python registration for compute nodes in EVT
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
| 38 |
+
from cutlass_cppgen.backend.library import FloatRoundStyle
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class ComputeImplBase(ImplBase):
|
| 42 |
+
"""
|
| 43 |
+
Base class for compute implementation
|
| 44 |
+
"""
|
| 45 |
+
def __init__(self, node) -> None:
|
| 46 |
+
super().__init__(node)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class ComputeImpl(ComputeImplBase):
|
| 50 |
+
"""
|
| 51 |
+
Implementation for Compute Node
|
| 52 |
+
"""
|
| 53 |
+
def __init__(self, node) -> None:
|
| 54 |
+
super().__init__(node)
|
| 55 |
+
|
| 56 |
+
self.fn = node.fn
|
| 57 |
+
self.element_output = node.element_output
|
| 58 |
+
self.element_compute = node.element_compute
|
| 59 |
+
self.round_style = node.round_style
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def match(node, problem_size: tuple):
|
| 63 |
+
return True
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ComputeNode(NodeBase):
|
| 67 |
+
"""
|
| 68 |
+
Compute Node in DAG IR
|
| 69 |
+
"""
|
| 70 |
+
possible_impls = [
|
| 71 |
+
ComputeImpl
|
| 72 |
+
]
|
| 73 |
+
def __init__(
|
| 74 |
+
self, name: str, fn, element_output,
|
| 75 |
+
element_compute,
|
| 76 |
+
round_style=FloatRoundStyle.ToNearest) -> None:
|
| 77 |
+
super().__init__(name)
|
| 78 |
+
self.op = "compute"
|
| 79 |
+
self.fn = fn
|
| 80 |
+
self.element_compute = element_compute
|
| 81 |
+
self.round_style = round_style
|
| 82 |
+
|
| 83 |
+
def type_propagation(self, *args, **kwargs):
|
| 84 |
+
"""
|
| 85 |
+
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
| 86 |
+
"""
|
| 87 |
+
self.element = self.element_compute
|
| 88 |
+
# In general, the compute nodes have element_output = element_compute
|
| 89 |
+
# In certain cases like producer of D it is overwritten by other passes
|
| 90 |
+
if not hasattr(self, "element_output"):
|
| 91 |
+
self.element_output = self.element
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/dag_ir.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
DAG IR used by Python EVT
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import networkx as nx
|
| 38 |
+
|
| 39 |
+
from cutlass_library import DataType
|
| 40 |
+
|
| 41 |
+
from cutlass_cppgen.backend.evt.ir.compute_nodes import ComputeNode
|
| 42 |
+
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
| 43 |
+
from cutlass_cppgen.backend.library import ActivationOp
|
| 44 |
+
from cutlass_cppgen.backend.utils import device_cc
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class DAGIR:
|
| 48 |
+
"""
|
| 49 |
+
``DAGIR`` is the main data structure used in the EVT Intermediate Representation.
|
| 50 |
+
It consists of a series of ``Node`` s, each representing epilogue visitor nodes.
|
| 51 |
+
|
| 52 |
+
In the DAGIR, ``node`` is an string of its name. ``node_meta`` is the underlying class of the node
|
| 53 |
+
"""
|
| 54 |
+
def __init__(self, cc, element_compute=DataType.f32) -> None:
|
| 55 |
+
# The EVT DAGIR is managed through the nextworkX Digraph class
|
| 56 |
+
self._graph = nx.DiGraph()
|
| 57 |
+
|
| 58 |
+
self.element_compute = element_compute
|
| 59 |
+
|
| 60 |
+
self.reduction_names = []
|
| 61 |
+
|
| 62 |
+
self.cc = cc
|
| 63 |
+
|
| 64 |
+
self.identity_counter = 0
|
| 65 |
+
|
| 66 |
+
#
|
| 67 |
+
# IR manipulator
|
| 68 |
+
#
|
| 69 |
+
|
| 70 |
+
def add_node(self, meta: NodeBase):
|
| 71 |
+
"""
|
| 72 |
+
Add a node to dag ir
|
| 73 |
+
"""
|
| 74 |
+
if self.has_node(meta.name):
|
| 75 |
+
raise SyntaxError(f"Variable '{meta.name}' cannot be defined twice.")
|
| 76 |
+
self._graph.add_node(meta.name, meta=meta)
|
| 77 |
+
|
| 78 |
+
def add_edge(self, src: str, dst: str, weight: int=0):
|
| 79 |
+
"""
|
| 80 |
+
Add an edge src -> dst to dag ir with weight
|
| 81 |
+
"""
|
| 82 |
+
if not self.has_node(src):
|
| 83 |
+
raise SyntaxError(f"Variable '{src}' is undefined.")
|
| 84 |
+
if not self.has_node(dst):
|
| 85 |
+
raise SyntaxError(f"Variable '{dst}' is undefined.")
|
| 86 |
+
|
| 87 |
+
if self._graph.has_edge(src, dst):
|
| 88 |
+
# The DiGraph doesn't support multiple edges between two nodes
|
| 89 |
+
# We insert an identity node in such case as a workaround
|
| 90 |
+
identity_name = f"autogen_identity_{self.identity_counter}"
|
| 91 |
+
self.identity_counter += 1
|
| 92 |
+
compute_node = ComputeNode(
|
| 93 |
+
name=identity_name, fn=ActivationOp.Identity,
|
| 94 |
+
element_output=self.element_compute,
|
| 95 |
+
element_compute=self.element_compute)
|
| 96 |
+
self.add_node(compute_node)
|
| 97 |
+
self.add_edge(src, identity_name, 0)
|
| 98 |
+
self.add_edge(identity_name, dst, weight)
|
| 99 |
+
else:
|
| 100 |
+
self._graph.add_edge(src, dst, weight=weight)
|
| 101 |
+
|
| 102 |
+
def remove_node(self, node: str):
|
| 103 |
+
"""
|
| 104 |
+
Remove node from dag ir
|
| 105 |
+
"""
|
| 106 |
+
self._graph.remove_node(node)
|
| 107 |
+
|
| 108 |
+
def remove_edge(self, src: str, dst: str):
|
| 109 |
+
"""
|
| 110 |
+
Remove edge src -> dst
|
| 111 |
+
"""
|
| 112 |
+
self._graph.remove_edge(src, dst)
|
| 113 |
+
|
| 114 |
+
#
|
| 115 |
+
# Helper functions for getting attrs
|
| 116 |
+
#
|
| 117 |
+
|
| 118 |
+
def has_node(self, node: str) -> bool:
|
| 119 |
+
"""
|
| 120 |
+
Check if the node is in the graph
|
| 121 |
+
"""
|
| 122 |
+
return self._graph.has_node(node)
|
| 123 |
+
|
| 124 |
+
def in_degree(self, node: str):
|
| 125 |
+
"""
|
| 126 |
+
Get the input degree of node
|
| 127 |
+
"""
|
| 128 |
+
return self._graph.in_degree(node)
|
| 129 |
+
|
| 130 |
+
def in_edges(self, node: str):
|
| 131 |
+
"""
|
| 132 |
+
Get the input edges of node
|
| 133 |
+
"""
|
| 134 |
+
return [edge for edge in self._graph.in_edges(node)]
|
| 135 |
+
|
| 136 |
+
def out_degree(self, node: str):
|
| 137 |
+
"""
|
| 138 |
+
Get the output degree of node
|
| 139 |
+
"""
|
| 140 |
+
return self._graph.out_degree(node)
|
| 141 |
+
|
| 142 |
+
def out_edges(self, node: str):
|
| 143 |
+
"""
|
| 144 |
+
Get the output edges of node
|
| 145 |
+
"""
|
| 146 |
+
return [edge for edge in self._graph.out_edges(node)]
|
| 147 |
+
|
| 148 |
+
def get_node_meta(self, node: str):
|
| 149 |
+
"""
|
| 150 |
+
Get the meta data of the node
|
| 151 |
+
"""
|
| 152 |
+
return self._graph.nodes[node]["meta"]
|
| 153 |
+
|
| 154 |
+
def get_edge_weight(self, src, dst):
|
| 155 |
+
"""
|
| 156 |
+
Get the edge weight of edge src->dst
|
| 157 |
+
"""
|
| 158 |
+
return self._graph.get_edge_data(src, dst)["weight"]
|
| 159 |
+
|
| 160 |
+
#
|
| 161 |
+
# High-level helper functions
|
| 162 |
+
#
|
| 163 |
+
|
| 164 |
+
def all_reachable_nodes(self, node: str):
|
| 165 |
+
"""
|
| 166 |
+
Get all the nodes reachable from the current node (exclude)
|
| 167 |
+
"""
|
| 168 |
+
return list(nx.dfs_preorder_nodes(self._graph, source=node))
|
| 169 |
+
|
| 170 |
+
def get_users(self, node: str):
|
| 171 |
+
"""
|
| 172 |
+
Get all users of the current node
|
| 173 |
+
"""
|
| 174 |
+
return [edge[1] for edge in self.out_edges(node)]
|
| 175 |
+
|
| 176 |
+
def get_all_inputs(self, node: str):
|
| 177 |
+
"""
|
| 178 |
+
Get all the input nodes sorted by edge weight
|
| 179 |
+
"""
|
| 180 |
+
in_edges = self.in_edges(node)
|
| 181 |
+
edge_weights = [self.get_edge_weight(*edge) for edge in in_edges]
|
| 182 |
+
return [edge[0] for _, edge in sorted(zip(edge_weights, in_edges))]
|
| 183 |
+
|
| 184 |
+
def get_all_inputs_meta(self, node: str):
|
| 185 |
+
"""
|
| 186 |
+
Get all the input node metas sorted by edge weight
|
| 187 |
+
"""
|
| 188 |
+
return [self.get_node_meta(input_node) for input_node in self.get_all_inputs(node)]
|
| 189 |
+
|
| 190 |
+
def replace_all_uses_with(self, node1, node2):
|
| 191 |
+
"""
|
| 192 |
+
Replace all uses of node1 with node2
|
| 193 |
+
"""
|
| 194 |
+
for edge in self.out_edges(node1):
|
| 195 |
+
weight = self.get_edge_weight(*edge)
|
| 196 |
+
user = edge[1]
|
| 197 |
+
self.add_edge(node2, user, weight)
|
| 198 |
+
self.remove_edge(node1, user)
|
| 199 |
+
self.remove_node(node1)
|
| 200 |
+
|
| 201 |
+
#
|
| 202 |
+
# Node accessor
|
| 203 |
+
#
|
| 204 |
+
def nodes_topological_order(self):
|
| 205 |
+
"""
|
| 206 |
+
Get the nodes in the unique lexicographical topological order
|
| 207 |
+
It generates a unique ordering of nodes by first sorting topologically
|
| 208 |
+
and then additionally by sorting lexicographically.
|
| 209 |
+
|
| 210 |
+
Although topological_sort alone also works, this generates a unique key
|
| 211 |
+
for each epilogue visitor pattern and ensures the compilation cache can be reused.
|
| 212 |
+
:return: list[str]
|
| 213 |
+
"""
|
| 214 |
+
return list(nx.lexicographical_topological_sort(self._graph))
|
| 215 |
+
|
| 216 |
+
def node_metas_topological_order(self):
|
| 217 |
+
"""
|
| 218 |
+
Get the node metas in topological order
|
| 219 |
+
:return: list[NodeBase]
|
| 220 |
+
"""
|
| 221 |
+
return [self.get_node_meta(node) for node in self.nodes_topological_order()]
|
| 222 |
+
|
| 223 |
+
@property
|
| 224 |
+
def nodes(self):
|
| 225 |
+
"""
|
| 226 |
+
Get all nodes
|
| 227 |
+
:return: list[str]
|
| 228 |
+
"""
|
| 229 |
+
return list(self._graph.nodes)
|
| 230 |
+
|
| 231 |
+
@property
|
| 232 |
+
def nodes_meta(self):
|
| 233 |
+
"""
|
| 234 |
+
Get all node metas
|
| 235 |
+
:return: list[NodeBase]
|
| 236 |
+
"""
|
| 237 |
+
return [data[1]['meta'] for data in self._graph.nodes.data()]
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def edges(self):
|
| 241 |
+
"""
|
| 242 |
+
Get all edges
|
| 243 |
+
:return: list[(str, str)]
|
| 244 |
+
"""
|
| 245 |
+
return list(self._graph.edges)
|
| 246 |
+
|
| 247 |
+
#
|
| 248 |
+
# Path
|
| 249 |
+
#
|
| 250 |
+
def has_path(self, src: str, target: str) -> bool:
|
| 251 |
+
"""
|
| 252 |
+
Return True is a path exists from src to target
|
| 253 |
+
"""
|
| 254 |
+
return nx.has_path(self._graph, src, target)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_algorithm.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Layout algebras
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from pycute import Layout, composition, make_layout, flatten, product
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _infer_split(old_shape, new_shape):
|
| 41 |
+
old_shape = _tuple_to_list(old_shape)
|
| 42 |
+
new_shape = _tuple_to_list(new_shape)
|
| 43 |
+
if len(old_shape) == 0 and len(new_shape) == 0:
|
| 44 |
+
return []
|
| 45 |
+
if len(old_shape) == 0:
|
| 46 |
+
if product(tuple(new_shape)) != 1:
|
| 47 |
+
raise ValueError("Invalid reshape size")
|
| 48 |
+
else:
|
| 49 |
+
return new_shape
|
| 50 |
+
if len(new_shape) == 0:
|
| 51 |
+
if product(tuple(old_shape)) != 1:
|
| 52 |
+
raise ValueError("Invalid reshape size")
|
| 53 |
+
else:
|
| 54 |
+
return old_shape
|
| 55 |
+
# This is done recursively by only process the last dimension at each time
|
| 56 |
+
old_dim = old_shape[-1]
|
| 57 |
+
new_dim = new_shape[-1]
|
| 58 |
+
# Exact match
|
| 59 |
+
if old_dim == new_dim:
|
| 60 |
+
return _infer_split(old_shape[:-1], new_shape[:-1]) + [new_dim,]
|
| 61 |
+
# Needs split
|
| 62 |
+
if old_dim > new_dim and old_dim % new_dim == 0:
|
| 63 |
+
residual = old_dim // new_dim
|
| 64 |
+
return _infer_split(old_shape[:-1] + [residual,], new_shape[:-1]) + [new_dim,]
|
| 65 |
+
# Needs merge
|
| 66 |
+
if old_dim < new_dim and new_dim % old_dim == 0:
|
| 67 |
+
residual = new_dim // old_dim
|
| 68 |
+
return _infer_split(old_shape[:-1], new_shape[:-1] + [residual,]) + [old_dim,]
|
| 69 |
+
|
| 70 |
+
raise NotImplementedError(f"Unsupported split: {old_shape} -> {new_shape}")
|
| 71 |
+
|
| 72 |
+
def _infer_merge(flatten_shape, shape):
|
| 73 |
+
flatten_shape = _tuple_to_list(flatten_shape)
|
| 74 |
+
shape = _tuple_to_list(shape)
|
| 75 |
+
idx_flat = 0
|
| 76 |
+
merged_shape = []
|
| 77 |
+
for dim in shape:
|
| 78 |
+
# Exact match
|
| 79 |
+
if dim == flatten_shape[idx_flat]:
|
| 80 |
+
merged_shape.append(dim)
|
| 81 |
+
idx_flat += 1
|
| 82 |
+
# Need group
|
| 83 |
+
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
|
| 84 |
+
residual = dim
|
| 85 |
+
group = []
|
| 86 |
+
while(residual > 1):
|
| 87 |
+
group.append(flatten_shape[idx_flat])
|
| 88 |
+
residual = residual // flatten_shape[idx_flat]
|
| 89 |
+
idx_flat += 1
|
| 90 |
+
merged_shape.append(group)
|
| 91 |
+
else:
|
| 92 |
+
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
|
| 93 |
+
|
| 94 |
+
return merged_shape
|
| 95 |
+
|
| 96 |
+
def _list_to_tuple(nested_list):
|
| 97 |
+
if isinstance(nested_list, list) or isinstance(nested_list, tuple):
|
| 98 |
+
return tuple(_list_to_tuple(item) for item in nested_list)
|
| 99 |
+
return nested_list
|
| 100 |
+
|
| 101 |
+
def _tuple_to_list(nested_tuple):
|
| 102 |
+
if isinstance(nested_tuple, list) or isinstance(nested_tuple, tuple):
|
| 103 |
+
return list(_tuple_to_list(item) for item in nested_tuple)
|
| 104 |
+
return nested_tuple
|
| 105 |
+
|
| 106 |
+
def _reverse_tuple(nested_tuple: tuple):
|
| 107 |
+
if isinstance(nested_tuple, tuple):
|
| 108 |
+
return tuple([_reverse_tuple(item) for item in nested_tuple][::-1])
|
| 109 |
+
return nested_tuple
|
| 110 |
+
|
| 111 |
+
def _get_first_lhs_nonzero_stride(stride_list, idx):
|
| 112 |
+
for i in reversed(range(idx)):
|
| 113 |
+
if stride_list[i] != 0:
|
| 114 |
+
return i
|
| 115 |
+
else:
|
| 116 |
+
return None
|
| 117 |
+
|
| 118 |
+
def _get_first_rhs_nonzero_stride(stride_list, idx):
|
| 119 |
+
for i in range(idx+1, len(stride_list)):
|
| 120 |
+
if stride_list[i] != 0:
|
| 121 |
+
return i
|
| 122 |
+
else:
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
def reshape(layout, new_shape):
|
| 126 |
+
"""
|
| 127 |
+
General reshape of input layout.
|
| 128 |
+
It takes two steps:
|
| 129 |
+
1. split the dimensions of the old layout
|
| 130 |
+
2. merge the splitted dimensions according to the new shape
|
| 131 |
+
"""
|
| 132 |
+
#
|
| 133 |
+
# Step 1: Split the dimensions of the old layout
|
| 134 |
+
#
|
| 135 |
+
# 1.1 Flat old and new shape
|
| 136 |
+
old_flatten_shape = list(flatten(layout.shape))
|
| 137 |
+
new_flatten_shape = list(flatten(new_shape))
|
| 138 |
+
|
| 139 |
+
# 1.2 Infer the flatten splitted shape
|
| 140 |
+
splitted_flatten_shape = _infer_split(old_flatten_shape, new_flatten_shape)
|
| 141 |
+
|
| 142 |
+
# 1.3 Unflat the splitted shape based on the old shape
|
| 143 |
+
splited_shape = _infer_merge(splitted_flatten_shape, old_flatten_shape)
|
| 144 |
+
|
| 145 |
+
# 1.4 Infer the type of each split
|
| 146 |
+
# If the split type is in row-major (R), the dimension list is reversed because
|
| 147 |
+
# the cute::composition only support column-major split
|
| 148 |
+
split_type = [] # the type of each split (ColumnMajor or RowMajor)
|
| 149 |
+
permuted_splitted_shape = []
|
| 150 |
+
old_flatten_stride = list(flatten(layout.stride))
|
| 151 |
+
for idx, dim in enumerate(splited_shape):
|
| 152 |
+
if not isinstance(dim, list):
|
| 153 |
+
permuted_splitted_shape.append(dim)
|
| 154 |
+
split_type.append("C")
|
| 155 |
+
else:
|
| 156 |
+
lhs_stride = _get_first_lhs_nonzero_stride(old_flatten_stride, idx)
|
| 157 |
+
rhs_stride = _get_first_rhs_nonzero_stride(old_flatten_stride, idx)
|
| 158 |
+
# Special case for single tuple
|
| 159 |
+
# Use column-major by default
|
| 160 |
+
if lhs_stride is None and rhs_stride is None:
|
| 161 |
+
permuted_splitted_shape.append(dim)
|
| 162 |
+
split_type.append("C")
|
| 163 |
+
else:
|
| 164 |
+
if lhs_stride is not None and rhs_stride is not None:
|
| 165 |
+
# We consider shape[idx]:stride[idx]
|
| 166 |
+
# Case 1: stride[idx - 1] <= stride[idx] <= stride[idx + 1]: column major
|
| 167 |
+
if lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
|
| 168 |
+
permuted_splitted_shape.append(dim)
|
| 169 |
+
split_type.append("C")
|
| 170 |
+
# Case 2: stride[idx - 1] > stride[idx] > stride[idx + 1]: row major
|
| 171 |
+
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
|
| 172 |
+
permuted_splitted_shape.append([d for d in reversed(dim)])
|
| 173 |
+
split_type.append("R")
|
| 174 |
+
# Case 3: stride[idx - 1] <= stride[idx] > stride[idx + 1]: concave
|
| 175 |
+
elif lhs_stride <= old_flatten_stride[idx] and old_flatten_stride[idx] > rhs_stride:
|
| 176 |
+
if lhs_stride >= rhs_stride:
|
| 177 |
+
permuted_splitted_shape.append(dim)
|
| 178 |
+
split_type.append("C")
|
| 179 |
+
else:
|
| 180 |
+
permuted_splitted_shape.append([d for d in reversed(dim)])
|
| 181 |
+
split_type.append("R")
|
| 182 |
+
# Case 4: stride[idx - 1] > stride[idx] <= stride[idx + 1]: concave
|
| 183 |
+
elif lhs_stride > old_flatten_stride[idx] and old_flatten_stride[idx] <= rhs_stride:
|
| 184 |
+
if lhs_stride >= rhs_stride:
|
| 185 |
+
permuted_splitted_shape.append(dim)
|
| 186 |
+
split_type.append("C")
|
| 187 |
+
else:
|
| 188 |
+
permuted_splitted_shape.append([d for d in reversed(dim)])
|
| 189 |
+
split_type.append("R")
|
| 190 |
+
else:
|
| 191 |
+
raise NotImplementedError()
|
| 192 |
+
elif lhs_stride is None:
|
| 193 |
+
# Case 1: dim's stride < dim+1's stride, expand in column major
|
| 194 |
+
if old_flatten_stride[idx] > rhs_stride:
|
| 195 |
+
permuted_splitted_shape.append([d for d in reversed(dim)])
|
| 196 |
+
split_type.append("R")
|
| 197 |
+
else:
|
| 198 |
+
permuted_splitted_shape.append(dim)
|
| 199 |
+
split_type.append("C")
|
| 200 |
+
else:
|
| 201 |
+
# Case 1: dim's stride > dim-1's stride
|
| 202 |
+
if old_flatten_stride[idx] < lhs_stride:
|
| 203 |
+
permuted_splitted_shape.append([d for d in reversed(dim)])
|
| 204 |
+
split_type.append("R")
|
| 205 |
+
else:
|
| 206 |
+
permuted_splitted_shape.append(dim)
|
| 207 |
+
split_type.append("C")
|
| 208 |
+
|
| 209 |
+
# 1.4 Generate the splitted layout
|
| 210 |
+
permuted_splitted_layout = composition(layout, Layout(_list_to_tuple(permuted_splitted_shape)))
|
| 211 |
+
|
| 212 |
+
# 1.5 Reverse the permutation in 1.4 before merge
|
| 213 |
+
splitted_shape = []
|
| 214 |
+
splitted_stride = []
|
| 215 |
+
for shape_dim, stride_dim, type in zip(
|
| 216 |
+
permuted_splitted_layout.shape,
|
| 217 |
+
permuted_splitted_layout.stride,
|
| 218 |
+
split_type):
|
| 219 |
+
if type == "C":
|
| 220 |
+
splitted_shape.append(shape_dim)
|
| 221 |
+
splitted_stride.append(stride_dim)
|
| 222 |
+
else:
|
| 223 |
+
splitted_shape.append(tuple([d for d in reversed(shape_dim)]))
|
| 224 |
+
splitted_stride.append(tuple([d for d in reversed(stride_dim)]))
|
| 225 |
+
splitted_layout = Layout(tuple(splitted_shape), tuple(splitted_stride))
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
#
|
| 229 |
+
# Step 2: Merge the splitted dimensions according to the new shape
|
| 230 |
+
#
|
| 231 |
+
# 2.1 Merge layout
|
| 232 |
+
merged_layout = composition(splitted_layout, Layout(new_shape))
|
| 233 |
+
|
| 234 |
+
# 2.2 Cleaning up
|
| 235 |
+
output_layout = composition(merged_layout, Layout(new_shape))
|
| 236 |
+
return output_layout
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def permutation(layout, permutation):
|
| 240 |
+
"""
|
| 241 |
+
Permute the layout
|
| 242 |
+
"""
|
| 243 |
+
new_shape = tuple([layout.shape[idx] for idx in permutation])
|
| 244 |
+
new_stride = tuple([layout.stride[idx] for idx in permutation])
|
| 245 |
+
return Layout(new_shape, new_stride)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def _broadcast(layout, new_shape):
|
| 249 |
+
if len(layout) == 1 and isinstance(new_shape, int):
|
| 250 |
+
old_dim = layout.shape
|
| 251 |
+
old_stride = layout.stride
|
| 252 |
+
new_dim = new_shape
|
| 253 |
+
if old_dim == new_dim:
|
| 254 |
+
return Layout(old_dim, old_stride)
|
| 255 |
+
elif old_dim == 1:
|
| 256 |
+
return Layout(new_dim, 0)
|
| 257 |
+
else:
|
| 258 |
+
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {new_dim}")
|
| 259 |
+
|
| 260 |
+
# Align the dimensions
|
| 261 |
+
old_shape = layout.shape
|
| 262 |
+
if isinstance(old_shape, int):
|
| 263 |
+
old_shape = (old_shape,)
|
| 264 |
+
sub_layouts = [layout,]
|
| 265 |
+
else:
|
| 266 |
+
sub_layouts = [sub_layout for sub_layout in layout]
|
| 267 |
+
rhs_broadcast_layouts = [Layout(1, 0)] * (len(new_shape) - len(old_shape))
|
| 268 |
+
# Get the broadcasted layout
|
| 269 |
+
broadcast_layouts = []
|
| 270 |
+
try:
|
| 271 |
+
layout = make_layout(*sub_layouts, *rhs_broadcast_layouts)
|
| 272 |
+
broadcast_layouts = []
|
| 273 |
+
for idx, sub_layout in enumerate(layout):
|
| 274 |
+
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
|
| 275 |
+
except NotImplementedError:
|
| 276 |
+
layout = make_layout(*rhs_broadcast_layouts, *sub_layouts)
|
| 277 |
+
for idx, sub_layout in enumerate(layout):
|
| 278 |
+
broadcast_layouts.append(_broadcast(sub_layout, new_shape[idx]))
|
| 279 |
+
return make_layout(*broadcast_layouts)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def broadcast(layout, new_shape):
|
| 283 |
+
"""
|
| 284 |
+
Broadcast the new layout based on the input shape
|
| 285 |
+
The broadcasted shape equals to the new shape
|
| 286 |
+
The stride of broadcasted dimensions are 0
|
| 287 |
+
"""
|
| 288 |
+
return _broadcast(layout, new_shape)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def debroadcast(layout, dims):
|
| 292 |
+
"""
|
| 293 |
+
Squeeze the 0-stride
|
| 294 |
+
"""
|
| 295 |
+
for dim in dims:
|
| 296 |
+
if layout.stride[dim] != 0:
|
| 297 |
+
raise ValueError(f"Dim{dim} cannot be debroadcasted as it has stride {layout.stride[dim]}")
|
| 298 |
+
new_shape = tuple([s for idx, s in enumerate(layout.shape) if idx not in dims])
|
| 299 |
+
new_stride = tuple([s for idx, s in enumerate(layout.stride) if idx not in dims])
|
| 300 |
+
return Layout(new_shape, new_stride)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def canonicalization_(shapes, strides):
|
| 304 |
+
if isinstance(shapes, tuple):
|
| 305 |
+
c_shapes = []
|
| 306 |
+
c_strides = []
|
| 307 |
+
for shape, stride in zip(shapes, strides):
|
| 308 |
+
c_shape, c_stride = canonicalization_(shape, stride)
|
| 309 |
+
c_shapes.append(c_shape)
|
| 310 |
+
c_strides.append(c_stride)
|
| 311 |
+
return tuple(c_shapes), tuple(c_strides)
|
| 312 |
+
else:
|
| 313 |
+
if shapes == 1:
|
| 314 |
+
return 1, 0
|
| 315 |
+
else:
|
| 316 |
+
return shapes, strides
|
| 317 |
+
|
| 318 |
+
def canonicalization(layout):
|
| 319 |
+
"""
|
| 320 |
+
Canonicalize the input layout
|
| 321 |
+
1. set the stride of shape "1" to 0
|
| 322 |
+
"""
|
| 323 |
+
new_shape, new_stride = canonicalization_(layout.shape, layout.stride)
|
| 324 |
+
return Layout(new_shape, new_stride)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py
ADDED
|
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Layout manipulation nodes and implementations
|
| 35 |
+
|
| 36 |
+
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
from copy import deepcopy
|
| 40 |
+
|
| 41 |
+
from cutlass_library import LayoutType
|
| 42 |
+
from pycute import product, flatten
|
| 43 |
+
|
| 44 |
+
import cutlass_cppgen
|
| 45 |
+
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
| 46 |
+
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
| 47 |
+
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class PermutationImpl:
|
| 51 |
+
"""
|
| 52 |
+
Detailed implementation and helper functions for permutation
|
| 53 |
+
"""
|
| 54 |
+
def __init__(self, node) -> None:
|
| 55 |
+
assert "indices" in node.kwargs.keys()
|
| 56 |
+
self.indices = list(node.kwargs["indices"])
|
| 57 |
+
self.inverse_indices = self.get_inverse_indices(self.indices)
|
| 58 |
+
|
| 59 |
+
def get_inverse_impl(self):
|
| 60 |
+
inverse_impl = deepcopy(self)
|
| 61 |
+
inverse_impl.indices = self.inverse_indices
|
| 62 |
+
inverse_impl.inverse_indices = self.indices
|
| 63 |
+
return inverse_impl
|
| 64 |
+
|
| 65 |
+
def update(self, shape):
|
| 66 |
+
num_dim = len(shape)
|
| 67 |
+
indices = self.indices
|
| 68 |
+
num_old_dim = len(indices)
|
| 69 |
+
# Add offset
|
| 70 |
+
for i, idx in enumerate(indices):
|
| 71 |
+
indices[i] = idx + num_dim - num_old_dim
|
| 72 |
+
# Add broadcast dims
|
| 73 |
+
for i in range(num_dim - num_old_dim):
|
| 74 |
+
indices = [i,] + indices
|
| 75 |
+
|
| 76 |
+
self.indices = indices
|
| 77 |
+
self.inverse_indices = self.get_inverse_indices(self.indices)
|
| 78 |
+
|
| 79 |
+
def get_inverse_indices(self, indices):
|
| 80 |
+
"""
|
| 81 |
+
Get the indices for inverse permutation
|
| 82 |
+
"""
|
| 83 |
+
num_dim = len(indices)
|
| 84 |
+
inverse_indices = [0] * num_dim
|
| 85 |
+
for i in range(num_dim):
|
| 86 |
+
inverse_indices[indices[i]] = i
|
| 87 |
+
return inverse_indices
|
| 88 |
+
|
| 89 |
+
def shape_propagation(self, input_node_meta):
|
| 90 |
+
input_shape = input_node_meta.tensor.shape
|
| 91 |
+
output_shape = tuple([input_shape[idx] for idx in self.indices])
|
| 92 |
+
return output_shape
|
| 93 |
+
|
| 94 |
+
def broadcast(self, shape, node_meta: NodeBase):
|
| 95 |
+
"""
|
| 96 |
+
Broadcast the inputs based on current shape
|
| 97 |
+
"""
|
| 98 |
+
self.update(shape)
|
| 99 |
+
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
|
| 100 |
+
node_meta.tensor.broadcast(inverse_shape)
|
| 101 |
+
|
| 102 |
+
def apply_to_user(self, usr_meta: NodeBase):
|
| 103 |
+
"""
|
| 104 |
+
Propagate the permutation to the users of the current nodes
|
| 105 |
+
"""
|
| 106 |
+
usr_meta.tensor.permute(self.inverse_indices)
|
| 107 |
+
if hasattr(usr_meta, "store_tensor"):
|
| 108 |
+
if usr_meta.store_tensor is not None:
|
| 109 |
+
usr_meta.store_tensor.permute(self.inverse_indices)
|
| 110 |
+
|
| 111 |
+
def apply_to_input(self, input_meta: NodeBase):
|
| 112 |
+
"""
|
| 113 |
+
Propagate the permutation to inputs of the current nodes
|
| 114 |
+
"""
|
| 115 |
+
input_meta.tensor.permute(self.indices)
|
| 116 |
+
if hasattr(input_meta, "store_tensor"):
|
| 117 |
+
if input_meta.store_tensor is not None:
|
| 118 |
+
input_meta.store_tensor.permute(self.indices)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class ReshapeImpl:
|
| 122 |
+
"""
|
| 123 |
+
Detailed implementation and helper functions for reshape
|
| 124 |
+
"""
|
| 125 |
+
def __init__(self, node) -> None:
|
| 126 |
+
self.node = node
|
| 127 |
+
assert "new_shape" in node.kwargs.keys()
|
| 128 |
+
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
|
| 129 |
+
|
| 130 |
+
def get_inverse_impl(self):
|
| 131 |
+
inverse_impl = deepcopy(self)
|
| 132 |
+
inverse_impl.output_shape = self.input_shape
|
| 133 |
+
inverse_impl.input_shape = self.output_shape
|
| 134 |
+
return inverse_impl
|
| 135 |
+
|
| 136 |
+
def shape_propagation(self, input_node_meta):
|
| 137 |
+
self.input_shape = input_node_meta.tensor.shape
|
| 138 |
+
return _list_to_tuple(self.output_shape)
|
| 139 |
+
|
| 140 |
+
def broadcast(self, shape, node_meta: NodeBase):
|
| 141 |
+
"""
|
| 142 |
+
Broadcast the inputs based on current shape.
|
| 143 |
+
"""
|
| 144 |
+
# Step 1: infer split
|
| 145 |
+
flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
|
| 146 |
+
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
|
| 147 |
+
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
|
| 148 |
+
|
| 149 |
+
# broadcast shape -> split_output_shape -> flatten_split_shape
|
| 150 |
+
if len(shape) - len(split_output_shape) > 0:
|
| 151 |
+
for _ in range(len(shape) - len(split_output_shape)):
|
| 152 |
+
split_output_shape = [1,] + split_output_shape
|
| 153 |
+
flatten_split_shape = [1,] + flatten_split_shape
|
| 154 |
+
split_input_shape = [1,] + split_input_shape
|
| 155 |
+
broadcast_factor = []
|
| 156 |
+
for dim, old_dim in zip(shape, split_output_shape):
|
| 157 |
+
if not isinstance(dim, list):
|
| 158 |
+
dim = [dim,]
|
| 159 |
+
if not isinstance(old_dim, list):
|
| 160 |
+
old_dim = [old_dim,]
|
| 161 |
+
if product(tuple(dim)) == product(tuple(old_dim)):
|
| 162 |
+
broadcast_factor += [1] * len(old_dim)
|
| 163 |
+
elif product(tuple(old_dim)) == 1:
|
| 164 |
+
assert len(dim) == 1
|
| 165 |
+
broadcast_factor.append(dim[0])
|
| 166 |
+
else:
|
| 167 |
+
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
|
| 168 |
+
|
| 169 |
+
# flatten_split_shape -> split_input_shape
|
| 170 |
+
factor_idx = 0
|
| 171 |
+
broadcast_split_input_shape = []
|
| 172 |
+
for dim in split_input_shape:
|
| 173 |
+
if isinstance(dim, list):
|
| 174 |
+
new_dim = []
|
| 175 |
+
for d in dim:
|
| 176 |
+
new_dim.append(d * broadcast_factor[factor_idx])
|
| 177 |
+
factor_idx += 1
|
| 178 |
+
broadcast_split_input_shape.append(new_dim)
|
| 179 |
+
else:
|
| 180 |
+
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
|
| 181 |
+
factor_idx += 1
|
| 182 |
+
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
|
| 183 |
+
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
|
| 184 |
+
node_meta.tensor.broadcast(broadcast_split_input_shape)
|
| 185 |
+
# Last reshape op to clean up
|
| 186 |
+
broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
|
| 187 |
+
node_meta.tensor.reshape(broadcast_input_shape)
|
| 188 |
+
# Update the input shape and output shape
|
| 189 |
+
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
|
| 190 |
+
self.output_shape = _list_to_tuple(shape)
|
| 191 |
+
|
| 192 |
+
def apply_to_user(self, user_meta: NodeBase):
|
| 193 |
+
"""
|
| 194 |
+
Propagate the reshape to user nodes
|
| 195 |
+
"""
|
| 196 |
+
user_meta.tensor.reshape(tuple(self.input_shape))
|
| 197 |
+
if hasattr(user_meta, "store_tensor"):
|
| 198 |
+
if user_meta.store_tensor is not None:
|
| 199 |
+
user_meta.store_tensor.reshape(tuple(self.input_shape))
|
| 200 |
+
|
| 201 |
+
def apply_to_input(self, input_meta: NodeBase):
|
| 202 |
+
"""
|
| 203 |
+
Propagate the reshape to input nodes
|
| 204 |
+
"""
|
| 205 |
+
input_meta.tensor.reshape(tuple(self.output_shape))
|
| 206 |
+
if hasattr(input_meta, "store_tensor"):
|
| 207 |
+
if input_meta.store_tensor is not None:
|
| 208 |
+
input_meta.store_tensor.reshape(tuple(self.output_shape))
|
| 209 |
+
|
| 210 |
+
#
|
| 211 |
+
# Helper functions
|
| 212 |
+
#
|
| 213 |
+
|
| 214 |
+
def infer_split(self, input_shape, output_shape):
|
| 215 |
+
"""
|
| 216 |
+
Infer the flatten splitted shape that can be merged to both input_shape and output_shape
|
| 217 |
+
"""
|
| 218 |
+
input_shape = _tuple_to_list(input_shape)
|
| 219 |
+
output_shape = _tuple_to_list(output_shape)
|
| 220 |
+
if len(input_shape) == 0 and len(output_shape) == 0:
|
| 221 |
+
return []
|
| 222 |
+
if len(input_shape) == 0:
|
| 223 |
+
if product(tuple(output_shape)) != 1:
|
| 224 |
+
raise ValueError("Invalid reshape size")
|
| 225 |
+
else:
|
| 226 |
+
return output_shape
|
| 227 |
+
if len(output_shape) == 0:
|
| 228 |
+
if product(tuple(input_shape)) != 1:
|
| 229 |
+
raise ValueError("Invalid reshape size")
|
| 230 |
+
else:
|
| 231 |
+
return input_shape
|
| 232 |
+
# This is done recursively by only process the last dimension at each time
|
| 233 |
+
old_dim = input_shape[-1]
|
| 234 |
+
new_dim = output_shape[-1]
|
| 235 |
+
# Exact match
|
| 236 |
+
if old_dim == new_dim:
|
| 237 |
+
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
|
| 238 |
+
# Needs split
|
| 239 |
+
if old_dim > new_dim and old_dim % new_dim == 0:
|
| 240 |
+
residual = old_dim // new_dim
|
| 241 |
+
return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
|
| 242 |
+
# Needs merge
|
| 243 |
+
if old_dim < new_dim and new_dim % old_dim == 0:
|
| 244 |
+
residual = new_dim // old_dim
|
| 245 |
+
return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
|
| 246 |
+
|
| 247 |
+
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
|
| 248 |
+
|
| 249 |
+
def infer_merge(self, flatten_shape, shape):
|
| 250 |
+
flatten_shape = _tuple_to_list(flatten_shape)
|
| 251 |
+
shape = _tuple_to_list(shape)
|
| 252 |
+
idx_flat = len(flatten_shape) - 1
|
| 253 |
+
merged_shape = []
|
| 254 |
+
for dim in reversed(shape):
|
| 255 |
+
# Exact match
|
| 256 |
+
if dim == flatten_shape[idx_flat]:
|
| 257 |
+
merged_shape.append(dim)
|
| 258 |
+
idx_flat -= 1
|
| 259 |
+
# need group
|
| 260 |
+
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
|
| 261 |
+
residual = dim
|
| 262 |
+
group = []
|
| 263 |
+
while(residual > 1):
|
| 264 |
+
group.append(flatten_shape[idx_flat])
|
| 265 |
+
residual = residual // flatten_shape[idx_flat]
|
| 266 |
+
idx_flat -= 1
|
| 267 |
+
merged_shape.append(group[::-1])
|
| 268 |
+
else:
|
| 269 |
+
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
|
| 270 |
+
|
| 271 |
+
return merged_shape[::-1]
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class LayoutNode(NodeBase):
|
| 275 |
+
"""
|
| 276 |
+
Layout manipulation nodes
|
| 277 |
+
"""
|
| 278 |
+
fn_to_impl = {
|
| 279 |
+
"permute": PermutationImpl,
|
| 280 |
+
"reshape": ReshapeImpl
|
| 281 |
+
}
|
| 282 |
+
def __init__(self, name: str, fn, kwargs: dict) -> None:
|
| 283 |
+
super().__init__(name)
|
| 284 |
+
self.op = "layout"
|
| 285 |
+
self.fn = fn
|
| 286 |
+
self.kwargs = kwargs
|
| 287 |
+
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
|
| 288 |
+
|
| 289 |
+
def get_inverse_node(self):
|
| 290 |
+
inverse_node = deepcopy(self)
|
| 291 |
+
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
|
| 292 |
+
return inverse_node
|
| 293 |
+
|
| 294 |
+
def shape_propagation(self, input_node_metas):
|
| 295 |
+
if self._tensor is not None:
|
| 296 |
+
return
|
| 297 |
+
assert len(input_node_metas) == 1, "Layout node can only have one input node"
|
| 298 |
+
|
| 299 |
+
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
|
| 300 |
+
|
| 301 |
+
self._tensor = Tensor(
|
| 302 |
+
element=self.element_output,
|
| 303 |
+
shape=output_shape, layout_tag=LayoutType.RowMajor
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
return super().shape_propagation(input_node_metas)
|
| 307 |
+
|
| 308 |
+
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 309 |
+
"""
|
| 310 |
+
The store nodes has element_output = element_input
|
| 311 |
+
"""
|
| 312 |
+
assert len(input_node_metas) == 1, "Layout node can only have one input node"
|
| 313 |
+
self.element_output = input_node_metas[0].element_output
|
| 314 |
+
|
| 315 |
+
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 316 |
+
"""
|
| 317 |
+
Propagate the broadcast in the reversed topological order
|
| 318 |
+
"""
|
| 319 |
+
if self.tensor is None:
|
| 320 |
+
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
| 321 |
+
shape = self.tensor.shape
|
| 322 |
+
|
| 323 |
+
for child in input_node_metas:
|
| 324 |
+
self.underlying_impl.broadcast(shape, child)
|
| 325 |
+
|
| 326 |
+
def apply_to_user(self, usr_meta: NodeBase):
|
| 327 |
+
"""
|
| 328 |
+
Propagate the permutation to user nodes
|
| 329 |
+
"""
|
| 330 |
+
self.underlying_impl.apply_to_user(usr_meta)
|
| 331 |
+
|
| 332 |
+
def apply_to_input(self, input_meta: NodeBase):
|
| 333 |
+
"""
|
| 334 |
+
Propagate the permutation to input nodes
|
| 335 |
+
"""
|
| 336 |
+
self.underlying_impl.apply_to_input(input_meta)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Load nodes and implementations
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import ctypes
|
| 38 |
+
|
| 39 |
+
from cutlass_cppgen.backend.c_types import tuple_factory
|
| 40 |
+
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
| 41 |
+
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class LoadImplBase(ImplBase):
|
| 45 |
+
"""
|
| 46 |
+
Base class for load node implementations
|
| 47 |
+
"""
|
| 48 |
+
reserved_names = ["accum", "C"]
|
| 49 |
+
def __init__(self, node) -> None:
|
| 50 |
+
super().__init__(node)
|
| 51 |
+
self.element = node.element
|
| 52 |
+
self.element_output = node.element_output
|
| 53 |
+
self.stride = node.tensor.stride
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class AccumulatorImpl(LoadImplBase):
|
| 57 |
+
"""
|
| 58 |
+
Accumulator node implementation
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
@staticmethod
|
| 62 |
+
def match(node, problem_size: tuple):
|
| 63 |
+
return node.name == "accum" and node.tensor.shape == problem_size
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class LoadSrcImpl(LoadImplBase):
|
| 67 |
+
"""
|
| 68 |
+
Load C implementation
|
| 69 |
+
"""
|
| 70 |
+
@property
|
| 71 |
+
def name_camel(self) -> str:
|
| 72 |
+
return "TensorC"
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def argument_type_c(self):
|
| 76 |
+
stride_mnl = self.get_stride_mnl()
|
| 77 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 78 |
+
class _Argument(ctypes.Structure):
|
| 79 |
+
_fields_ = [
|
| 80 |
+
("ptr_C", ctypes.c_void_p),
|
| 81 |
+
("stride_C", tuple_type)
|
| 82 |
+
]
|
| 83 |
+
def __init__(self, ptr) -> None:
|
| 84 |
+
self.ptr_C = ptr
|
| 85 |
+
self.stride_C = tuple_type(stride_mnl)
|
| 86 |
+
|
| 87 |
+
return _Argument
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def match(node, problem_size: tuple):
|
| 91 |
+
return node.name == "C" and node.tensor.shape == problem_size
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class AuxLoadImpl(LoadImplBase):
|
| 95 |
+
"""
|
| 96 |
+
Load arbitrary tensor
|
| 97 |
+
"""
|
| 98 |
+
@property
|
| 99 |
+
def argument_type(self):
|
| 100 |
+
stride_mnl = self.get_stride_mnl()
|
| 101 |
+
name = self.name
|
| 102 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 103 |
+
element_type = self.element
|
| 104 |
+
class _Argument(ctypes.Structure):
|
| 105 |
+
_fields_ = [
|
| 106 |
+
("ptr_aux", ctypes.c_void_p),
|
| 107 |
+
("null_default", dtype2ctype[element_type]),
|
| 108 |
+
("dAux", tuple_type)
|
| 109 |
+
]
|
| 110 |
+
def __init__(self, kwargs) -> None:
|
| 111 |
+
ptr = kwargs[name]
|
| 112 |
+
self.ptr_aux = ptr
|
| 113 |
+
self.null_default = to_ctype_value(0, element_type)
|
| 114 |
+
self.dAux = tuple_type(stride_mnl)
|
| 115 |
+
|
| 116 |
+
return _Argument
|
| 117 |
+
|
| 118 |
+
@staticmethod
|
| 119 |
+
def match(node, problem_size: tuple):
|
| 120 |
+
if node.name in LoadImplBase.reserved_names:
|
| 121 |
+
return False
|
| 122 |
+
strideMN = node.tensor.stride[-2:]
|
| 123 |
+
if (strideMN[0] == 1 and strideMN[1] != 0 or
|
| 124 |
+
strideMN[0] != 0 and strideMN[1] == 1 ):
|
| 125 |
+
return True
|
| 126 |
+
else:
|
| 127 |
+
return False
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class RowBroadcastImpl(LoadImplBase):
|
| 131 |
+
"""
|
| 132 |
+
Broadcast a row vector
|
| 133 |
+
"""
|
| 134 |
+
def __init__(self, node) -> None:
|
| 135 |
+
super().__init__(node)
|
| 136 |
+
self.stride_dtype = "int"
|
| 137 |
+
|
| 138 |
+
@property
|
| 139 |
+
def argument_type(self):
|
| 140 |
+
stride_mnl = self.get_stride_mnl()
|
| 141 |
+
name = self.name
|
| 142 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 143 |
+
element_type = self.element
|
| 144 |
+
class _Argument(ctypes.Structure):
|
| 145 |
+
_fields_ = [
|
| 146 |
+
("ptr_row", ctypes.c_void_p),
|
| 147 |
+
("null_default", dtype2ctype[element_type]),
|
| 148 |
+
("dRow", tuple_type)
|
| 149 |
+
]
|
| 150 |
+
def __init__(self, kwargs) -> None:
|
| 151 |
+
ptr = kwargs[name]
|
| 152 |
+
self.ptr_row = ptr
|
| 153 |
+
self.null_default = to_ctype_value(0, element_type)
|
| 154 |
+
self.dRow = tuple_type(stride_mnl)
|
| 155 |
+
|
| 156 |
+
return _Argument
|
| 157 |
+
|
| 158 |
+
@staticmethod
|
| 159 |
+
def match(node, problem_size: tuple):
|
| 160 |
+
if node.name in LoadImplBase.reserved_names:
|
| 161 |
+
return False
|
| 162 |
+
|
| 163 |
+
strideMN = node.tensor.stride[-2:]
|
| 164 |
+
if strideMN == (0, 1):
|
| 165 |
+
return True
|
| 166 |
+
else:
|
| 167 |
+
return False
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class ColumnBroadcastImpl(LoadImplBase):
|
| 171 |
+
"""
|
| 172 |
+
Broadcast a column vector
|
| 173 |
+
"""
|
| 174 |
+
def __init__(self, node) -> None:
|
| 175 |
+
super().__init__(node)
|
| 176 |
+
self.stride_dtype = "int"
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def argument_type(self):
|
| 180 |
+
stride_mnl = self.get_stride_mnl()
|
| 181 |
+
name = self.name
|
| 182 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 183 |
+
element_type = self.element
|
| 184 |
+
class _Argument(ctypes.Structure):
|
| 185 |
+
_fields_ = [
|
| 186 |
+
("ptr_col", ctypes.c_void_p),
|
| 187 |
+
("null_default", dtype2ctype[element_type]),
|
| 188 |
+
("dCol", tuple_type)
|
| 189 |
+
]
|
| 190 |
+
def __init__(self, kwargs) -> None:
|
| 191 |
+
ptr = kwargs[name]
|
| 192 |
+
self.ptr_col = int(ptr)
|
| 193 |
+
self.null_default = to_ctype_value(0, element_type)
|
| 194 |
+
self.dCol = tuple_type(stride_mnl)
|
| 195 |
+
|
| 196 |
+
return _Argument
|
| 197 |
+
|
| 198 |
+
@staticmethod
|
| 199 |
+
def match(node, problem_size: tuple):
|
| 200 |
+
if node.name in LoadImplBase.reserved_names:
|
| 201 |
+
return False
|
| 202 |
+
|
| 203 |
+
strideMN = node.tensor.stride[-2:]
|
| 204 |
+
if strideMN == (1, 0):
|
| 205 |
+
return True
|
| 206 |
+
else:
|
| 207 |
+
return False
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class ScalarBroadcastImpl(LoadImplBase):
|
| 211 |
+
"""
|
| 212 |
+
Broadcast a scalar
|
| 213 |
+
"""
|
| 214 |
+
def __init__(self, node) -> None:
|
| 215 |
+
super().__init__(node)
|
| 216 |
+
self.stride_dtype = "int"
|
| 217 |
+
|
| 218 |
+
@property
|
| 219 |
+
def argument_type(self):
|
| 220 |
+
stride_mnl = self.get_stride_mnl()
|
| 221 |
+
name = self.name
|
| 222 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 223 |
+
element_type = self.element
|
| 224 |
+
|
| 225 |
+
if self.tensor.is_constant:
|
| 226 |
+
value = self.tensor.value
|
| 227 |
+
class _Argument(ctypes.Structure):
|
| 228 |
+
_fields_ = [
|
| 229 |
+
("scalars", dtype2ctype[element_type]),
|
| 230 |
+
("scalar_ptrs", ctypes.c_void_p),
|
| 231 |
+
("dScalar", tuple_type)
|
| 232 |
+
]
|
| 233 |
+
def __init__(self, kwargs) -> None:
|
| 234 |
+
self.scalars = to_ctype_value(value, element_type)
|
| 235 |
+
self.scalar_ptrs = 0
|
| 236 |
+
self.dScalar = tuple_type(stride_mnl)
|
| 237 |
+
|
| 238 |
+
else:
|
| 239 |
+
class _Argument(ctypes.Structure):
|
| 240 |
+
_fields_ = [
|
| 241 |
+
("scalars", dtype2ctype[element_type]),
|
| 242 |
+
("scalar_ptrs", ctypes.c_void_p),
|
| 243 |
+
("dScalar", tuple_type)
|
| 244 |
+
]
|
| 245 |
+
def __init__(self, kwargs) -> None:
|
| 246 |
+
scalar_or_ptr = kwargs[name]
|
| 247 |
+
if isinstance(scalar_or_ptr, float):
|
| 248 |
+
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
|
| 249 |
+
self.scalar_ptrs = 0
|
| 250 |
+
else:
|
| 251 |
+
self.scalar_ptrs = int(scalar_or_ptr)
|
| 252 |
+
|
| 253 |
+
self.dScalar = tuple_type(stride_mnl)
|
| 254 |
+
|
| 255 |
+
return _Argument
|
| 256 |
+
|
| 257 |
+
@staticmethod
|
| 258 |
+
def match(node, problem_size: tuple):
|
| 259 |
+
if node.name in LoadImplBase.reserved_names:
|
| 260 |
+
return False
|
| 261 |
+
|
| 262 |
+
strideMN = node.tensor.stride[-2:]
|
| 263 |
+
if strideMN == (0, 0):
|
| 264 |
+
return True
|
| 265 |
+
else:
|
| 266 |
+
return False
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class LoadNode(NodeBase):
|
| 270 |
+
"""
|
| 271 |
+
Load Node
|
| 272 |
+
"""
|
| 273 |
+
cnt = 0
|
| 274 |
+
possible_impls = [
|
| 275 |
+
AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
|
| 276 |
+
RowBroadcastImpl, ColumnBroadcastImpl,
|
| 277 |
+
ScalarBroadcastImpl
|
| 278 |
+
]
|
| 279 |
+
def __init__(self, name: str) -> None:
|
| 280 |
+
if name is None:
|
| 281 |
+
name = f"load{LoadNode.cnt}"
|
| 282 |
+
LoadNode.cnt += 1
|
| 283 |
+
super().__init__(name)
|
| 284 |
+
self.op = "load"
|
| 285 |
+
|
| 286 |
+
def type_propagation(self, *args, **kwargs):
|
| 287 |
+
"""
|
| 288 |
+
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
| 289 |
+
"""
|
| 290 |
+
if self.tensor is None:
|
| 291 |
+
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
| 292 |
+
|
| 293 |
+
self.element = self.tensor.element
|
| 294 |
+
self.element_output = self.tensor.element
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Base & visitor classes of DAGIR Nodes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import ctypes
|
| 38 |
+
from re import sub
|
| 39 |
+
|
| 40 |
+
from cutlass_library import LayoutType
|
| 41 |
+
|
| 42 |
+
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
| 43 |
+
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TupleEmitter:
|
| 47 |
+
"""
|
| 48 |
+
Emit the cute tuple to C++ code
|
| 49 |
+
"""
|
| 50 |
+
def __init__(self, stride_dtype):
|
| 51 |
+
self.stride_dtype = stride_dtype
|
| 52 |
+
|
| 53 |
+
def emit(self, py_tuple):
|
| 54 |
+
if isinstance(py_tuple, int):
|
| 55 |
+
if py_tuple in [0, 1]:
|
| 56 |
+
return f"cute::Int<{py_tuple}>"
|
| 57 |
+
else:
|
| 58 |
+
return f"{self.stride_dtype}"
|
| 59 |
+
elif isinstance(py_tuple, tuple):
|
| 60 |
+
decl = "cute::Stride<"
|
| 61 |
+
for item in py_tuple:
|
| 62 |
+
decl += self.emit(item) + ", "
|
| 63 |
+
return decl[:-2] + ">"
|
| 64 |
+
else:
|
| 65 |
+
raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class ImplBase:
|
| 69 |
+
"""
|
| 70 |
+
Base class for Node Implementation
|
| 71 |
+
"""
|
| 72 |
+
def __init__(self, node) -> None:
|
| 73 |
+
self.node = node
|
| 74 |
+
self.name = node.name
|
| 75 |
+
self.tensor = node.tensor
|
| 76 |
+
self._type_decl = None
|
| 77 |
+
self.tuple_emitter = TupleEmitter("int64_t")
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def stride_dtype(self):
|
| 81 |
+
return self.tuple_emitter.stride_dtype
|
| 82 |
+
|
| 83 |
+
@stride_dtype.setter
|
| 84 |
+
def stride_dtype(self, stride_dtype):
|
| 85 |
+
self.tuple_emitter.stride_dtype = stride_dtype
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def match(node, problem_size: tuple):
|
| 89 |
+
"""
|
| 90 |
+
Match function used in get_underlying_impl
|
| 91 |
+
"""
|
| 92 |
+
raise NotImplementedError(f"The `match` function is not defined.")
|
| 93 |
+
|
| 94 |
+
@property
|
| 95 |
+
def argument_type(self):
|
| 96 |
+
"""
|
| 97 |
+
Default class for Argument Type
|
| 98 |
+
"""
|
| 99 |
+
class _Argument(ctypes.Structure):
|
| 100 |
+
_fields_ = []
|
| 101 |
+
|
| 102 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 103 |
+
pass
|
| 104 |
+
|
| 105 |
+
return _Argument
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def name_camel(self) -> str:
|
| 109 |
+
"""
|
| 110 |
+
Return the CamelCase name.
|
| 111 |
+
"""
|
| 112 |
+
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
|
| 113 |
+
|
| 114 |
+
@property
|
| 115 |
+
def stride_mnl(self):
|
| 116 |
+
"""
|
| 117 |
+
Typename StrideMNL
|
| 118 |
+
"""
|
| 119 |
+
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
| 120 |
+
return self.tuple_emitter.emit(stride)
|
| 121 |
+
|
| 122 |
+
def get_non_constant_stride(self, py_tuple):
|
| 123 |
+
if isinstance(py_tuple, int):
|
| 124 |
+
if py_tuple not in [0, 1]:
|
| 125 |
+
return py_tuple
|
| 126 |
+
else:
|
| 127 |
+
return None
|
| 128 |
+
non_constant_stride = []
|
| 129 |
+
for item in py_tuple:
|
| 130 |
+
item_out = self.get_non_constant_stride(item)
|
| 131 |
+
if item_out:
|
| 132 |
+
non_constant_stride.append(item_out)
|
| 133 |
+
return tuple(non_constant_stride)
|
| 134 |
+
|
| 135 |
+
def get_stride_mnl(self):
|
| 136 |
+
"""
|
| 137 |
+
Get the non-zero stride mnl. This is used in argument construction
|
| 138 |
+
"""
|
| 139 |
+
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
| 140 |
+
return stride
|
| 141 |
+
|
| 142 |
+
def get_smem_size(self, *args, **kwargs):
|
| 143 |
+
"""
|
| 144 |
+
Get the shared memory size and alignment of current node
|
| 145 |
+
"""
|
| 146 |
+
return (0, 1)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class NoOpImpl(ImplBase):
|
| 150 |
+
"""
|
| 151 |
+
The NoOpImpl does nothing but forward its input to users
|
| 152 |
+
"""
|
| 153 |
+
def __init__(self, node) -> None:
|
| 154 |
+
super().__init__(node)
|
| 155 |
+
|
| 156 |
+
@staticmethod
|
| 157 |
+
def match(node, problem_size: tuple):
|
| 158 |
+
if node.op == "store":
|
| 159 |
+
# Store that is not output is a No OP
|
| 160 |
+
return not node.is_output
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class NodeBase:
|
| 164 |
+
"""
|
| 165 |
+
Base class of DAG Node
|
| 166 |
+
"""
|
| 167 |
+
def __init__(self, name: str) -> None:
|
| 168 |
+
self.name = name
|
| 169 |
+
self.underlying_impl = None
|
| 170 |
+
|
| 171 |
+
self._tensor = None
|
| 172 |
+
|
| 173 |
+
# Whether the node is disabled for emit
|
| 174 |
+
self.disabled = False
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def name_camel(self) -> str:
|
| 178 |
+
"""
|
| 179 |
+
Return the CamelCase name.
|
| 180 |
+
"""
|
| 181 |
+
return self.underlying_impl.name_camel
|
| 182 |
+
|
| 183 |
+
@property
|
| 184 |
+
def tensor(self) -> Tensor:
|
| 185 |
+
"""
|
| 186 |
+
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
| 187 |
+
"""
|
| 188 |
+
return self._tensor
|
| 189 |
+
|
| 190 |
+
@tensor.setter
|
| 191 |
+
def tensor(self, kwargs):
|
| 192 |
+
"""
|
| 193 |
+
Setting the tensor
|
| 194 |
+
"""
|
| 195 |
+
self._tensor = Tensor(**kwargs)
|
| 196 |
+
|
| 197 |
+
#
|
| 198 |
+
# Helper functions for type/shape propagation
|
| 199 |
+
#
|
| 200 |
+
|
| 201 |
+
def shape_propagation(self, input_node_metas):
|
| 202 |
+
"""
|
| 203 |
+
Infer shape from input nodes
|
| 204 |
+
General Broadcasting Rules from NumPy
|
| 205 |
+
When operating on two arrays, we compare their shapes element-wise.
|
| 206 |
+
It starts with the trailing (i.e. rightmost) dimension and works its
|
| 207 |
+
way left. Two dimensions are compatible when
|
| 208 |
+
1. they are equal
|
| 209 |
+
2. one of them is 1
|
| 210 |
+
"""
|
| 211 |
+
if self._tensor is not None:
|
| 212 |
+
return
|
| 213 |
+
|
| 214 |
+
shape = None
|
| 215 |
+
for src in input_node_metas:
|
| 216 |
+
src_shape = src.tensor.shape
|
| 217 |
+
if shape is None:
|
| 218 |
+
shape = src_shape
|
| 219 |
+
else:
|
| 220 |
+
len_difference = len(shape) - len(src_shape)
|
| 221 |
+
if len_difference > 0:
|
| 222 |
+
for _ in range(len_difference):
|
| 223 |
+
src_shape = [1, ] + list(src_shape)
|
| 224 |
+
elif len_difference < 0:
|
| 225 |
+
for _ in range(-len_difference):
|
| 226 |
+
shape = [1, ] + list(shape)
|
| 227 |
+
broadcasted_shape = []
|
| 228 |
+
# Infer broadcast shape
|
| 229 |
+
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
|
| 230 |
+
if shape_dim == 1:
|
| 231 |
+
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
|
| 232 |
+
elif src_dim == 1:
|
| 233 |
+
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
| 234 |
+
elif shape_dim == src_dim:
|
| 235 |
+
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
| 236 |
+
else:
|
| 237 |
+
error_msg = "Dimension mismatch between "
|
| 238 |
+
for src_ in input_node_metas:
|
| 239 |
+
error_msg += f"{src_.name}{src_.tensor.shape}, "
|
| 240 |
+
error_msg = error_msg[:-2] + "."
|
| 241 |
+
raise RuntimeError(error_msg)
|
| 242 |
+
shape = tuple(broadcasted_shape)
|
| 243 |
+
|
| 244 |
+
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
|
| 245 |
+
|
| 246 |
+
def type_propagation(self, *args, **kwargs):
|
| 247 |
+
"""
|
| 248 |
+
Each node is associated with two data types: `element` and `element_output`.
|
| 249 |
+
The `element_output` is the type of return array of the node. The `element`
|
| 250 |
+
has specific meaning for different node types.
|
| 251 |
+
* Load Node: data type of tensor in gmem
|
| 252 |
+
* Compute Node: element compute
|
| 253 |
+
* Store Node: data type of tensor in gmem
|
| 254 |
+
This function must be overloaded in the derived classes
|
| 255 |
+
"""
|
| 256 |
+
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
|
| 257 |
+
|
| 258 |
+
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 259 |
+
"""
|
| 260 |
+
Propagate the broadcast in the reversed topological order.
|
| 261 |
+
For example:
|
| 262 |
+
C[l, m, n] = A[m, 1] + B[l, m, n]
|
| 263 |
+
After the broadcast propagation, it will be come
|
| 264 |
+
C[l, m, n] = A[l, m, n] + B[l, m, n]
|
| 265 |
+
and each tensor will have a proper stride accessing the underlying tensor
|
| 266 |
+
"""
|
| 267 |
+
if self.tensor is None:
|
| 268 |
+
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
| 269 |
+
for child in input_node_metas:
|
| 270 |
+
child.tensor.broadcast(self.tensor.shape)
|
| 271 |
+
|
| 272 |
+
def get_underlying_impl(self, problem_size: tuple):
|
| 273 |
+
"""
|
| 274 |
+
Get the underlying implementation of the current node.
|
| 275 |
+
"""
|
| 276 |
+
if self.tensor is None:
|
| 277 |
+
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
|
| 278 |
+
|
| 279 |
+
for impl in self.possible_impls:
|
| 280 |
+
if impl.match(self, problem_size):
|
| 281 |
+
self.underlying_impl = impl(self)
|
| 282 |
+
break
|
| 283 |
+
|
| 284 |
+
if self.underlying_impl is None:
|
| 285 |
+
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
|
| 286 |
+
|
| 287 |
+
#
|
| 288 |
+
# Visitor Nodes & Impls
|
| 289 |
+
#
|
| 290 |
+
|
| 291 |
+
class TopoVisitorImpl(ImplBase):
|
| 292 |
+
"""
|
| 293 |
+
Impl for topological visitor
|
| 294 |
+
"""
|
| 295 |
+
def __init__(self, node) -> None:
|
| 296 |
+
super().__init__(node.output_node)
|
| 297 |
+
self.name = node.name
|
| 298 |
+
self.element_output = node.output_node.element_output
|
| 299 |
+
|
| 300 |
+
class TopoVisitorNode(NodeBase):
|
| 301 |
+
def __init__(self, name: str, subgraph, output_node) -> None:
|
| 302 |
+
super().__init__(name)
|
| 303 |
+
self.subgraph = subgraph
|
| 304 |
+
self.output_node = output_node
|
| 305 |
+
self.op = "dag"
|
| 306 |
+
self.underlying_impl = TopoVisitorImpl(self)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Store node and implementations
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import ctypes
|
| 38 |
+
|
| 39 |
+
from cutlass_library import DataType
|
| 40 |
+
|
| 41 |
+
from cutlass_cppgen.backend.c_types import tuple_factory
|
| 42 |
+
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
| 43 |
+
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
|
| 44 |
+
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
| 45 |
+
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class StoreImplBase(ImplBase):
|
| 49 |
+
"""
|
| 50 |
+
Base class for store node implementation
|
| 51 |
+
"""
|
| 52 |
+
reserved_names = ["D"]
|
| 53 |
+
def __init__(self, node) -> None:
|
| 54 |
+
super().__init__(node)
|
| 55 |
+
self.element = node.element
|
| 56 |
+
self.element_output = node.element_output
|
| 57 |
+
self.stride = node.store_tensor.stride
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class StoreDImpl(StoreImplBase):
|
| 61 |
+
"""
|
| 62 |
+
Store D implementation
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def argument_type_d(self):
|
| 67 |
+
stride_mnl = self.get_stride_mnl()
|
| 68 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 69 |
+
class _Argument(ctypes.Structure):
|
| 70 |
+
_fields_ = [
|
| 71 |
+
("ptr_D", ctypes.c_void_p),
|
| 72 |
+
("stride_D", tuple_type)
|
| 73 |
+
]
|
| 74 |
+
def __init__(self, ptr: int) -> None:
|
| 75 |
+
self.ptr_D = ptr
|
| 76 |
+
self.stride_D = tuple_type(stride_mnl)
|
| 77 |
+
|
| 78 |
+
return _Argument
|
| 79 |
+
|
| 80 |
+
@staticmethod
|
| 81 |
+
def match(node, problem_size: tuple):
|
| 82 |
+
if node.name == "D" and node.store_tensor.shape == problem_size:
|
| 83 |
+
return True
|
| 84 |
+
return False
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class AuxStoreImpl(StoreImplBase):
|
| 88 |
+
def __init__(self, node) -> None:
|
| 89 |
+
super().__init__(node)
|
| 90 |
+
self.round_style = FloatRoundStyle.ToNearest
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def argument_type(self):
|
| 94 |
+
stride_mnl = self.get_stride_mnl()
|
| 95 |
+
name = self.name
|
| 96 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 97 |
+
class _Argument(ctypes.Structure):
|
| 98 |
+
_fields_ = [
|
| 99 |
+
("ptr_aux", ctypes.c_void_p),
|
| 100 |
+
("dAux", tuple_type)
|
| 101 |
+
]
|
| 102 |
+
def __init__(self, kwargs) -> None:
|
| 103 |
+
ptr = kwargs[name]
|
| 104 |
+
self.ptr_aux = ptr
|
| 105 |
+
self.dAux = tuple_type(stride_mnl)
|
| 106 |
+
|
| 107 |
+
return _Argument
|
| 108 |
+
|
| 109 |
+
@staticmethod
|
| 110 |
+
def match(node, problem_size: tuple):
|
| 111 |
+
if not node.is_output:
|
| 112 |
+
return False
|
| 113 |
+
if node.name in StoreImplBase.reserved_names:
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
strideMN = node.store_tensor.stride[-2:]
|
| 117 |
+
if (strideMN[0] == 1 and strideMN[1] != 0 or
|
| 118 |
+
strideMN[0] != 0 and strideMN[1] == 1 ):
|
| 119 |
+
return True
|
| 120 |
+
else:
|
| 121 |
+
return False
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class ReductionImplBase(StoreImplBase):
|
| 125 |
+
def __init__(self, node) -> None:
|
| 126 |
+
super().__init__(node)
|
| 127 |
+
self.element = node.store_tensor.element
|
| 128 |
+
self.element_compute = node.element_compute
|
| 129 |
+
self.reg_reduce_fn = self.node.reg_reduce_fn
|
| 130 |
+
self.gmem_reduce_fn = self.node.gmem_reduce_fn
|
| 131 |
+
self.round_style = node.round_style
|
| 132 |
+
self.stride_dtype = "int"
|
| 133 |
+
|
| 134 |
+
def get_reduce_identity(self):
|
| 135 |
+
"""
|
| 136 |
+
Return the reduction identity of the current reduce_fn
|
| 137 |
+
"""
|
| 138 |
+
maxes = {
|
| 139 |
+
DataType.f32: (2 ** 31) - 1,
|
| 140 |
+
DataType.f16: (2 ** 15),
|
| 141 |
+
DataType.s32: (2 ** 31) - 1,
|
| 142 |
+
DataType.s8: (2 ** 7) - 1
|
| 143 |
+
}
|
| 144 |
+
mins = {
|
| 145 |
+
DataType.f32: -maxes[DataType.f32],
|
| 146 |
+
DataType.f16: -maxes[DataType.f16],
|
| 147 |
+
DataType.s32: -maxes[DataType.s32],
|
| 148 |
+
DataType.s8: -maxes[DataType.s8]
|
| 149 |
+
}
|
| 150 |
+
if self.reg_reduce_fn == FunctionalOp.Maximum:
|
| 151 |
+
if self.element_compute not in mins:
|
| 152 |
+
raise Exception(f"No min entry for data type {self.element_compute}")
|
| 153 |
+
return to_ctype_value(mins[self.element_compute], self.element_compute)
|
| 154 |
+
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
|
| 155 |
+
return to_ctype_value(1., self.element_compute)
|
| 156 |
+
elif self.reg_reduce_fn == FunctionalOp.Minimum:
|
| 157 |
+
if self.element_compute not in maxes:
|
| 158 |
+
raise Exception(f"No max entry for data type {self.element_compute}")
|
| 159 |
+
return to_ctype_value(maxes[self.element_compute], self.element_compute)
|
| 160 |
+
else:
|
| 161 |
+
return to_ctype_value(0., self.element_compute)
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def argument_type(self):
|
| 165 |
+
self.get_reduce_identity()
|
| 166 |
+
stride_mnl = self.get_stride_mnl()
|
| 167 |
+
name = self.name
|
| 168 |
+
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 169 |
+
element_compute = self.element_compute
|
| 170 |
+
reduce_identity = self.get_reduce_identity()
|
| 171 |
+
class _Argument(ctypes.Structure):
|
| 172 |
+
_fields_ = [
|
| 173 |
+
("ptr", ctypes.c_void_p),
|
| 174 |
+
("reduce_identity", dtype2ctype[element_compute]),
|
| 175 |
+
("dMNL", tuple_type)
|
| 176 |
+
]
|
| 177 |
+
def __init__(self, kwargs) -> None:
|
| 178 |
+
ptr = kwargs[name]
|
| 179 |
+
self.ptr = ptr
|
| 180 |
+
self.reduce_identity = reduce_identity
|
| 181 |
+
self.dMNL = tuple_type(stride_mnl)
|
| 182 |
+
|
| 183 |
+
return _Argument
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class ColumnReductionImpl(ReductionImplBase):
|
| 187 |
+
|
| 188 |
+
@staticmethod
|
| 189 |
+
def match(node, problem_size: tuple):
|
| 190 |
+
if not node.is_output:
|
| 191 |
+
return False
|
| 192 |
+
if node.name in StoreImplBase.reserved_names:
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
+
strideMN = node.store_tensor.stride[-2:]
|
| 196 |
+
if strideMN == (1, 0):
|
| 197 |
+
return True
|
| 198 |
+
else:
|
| 199 |
+
return False
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class RowReductionImpl(ReductionImplBase):
|
| 203 |
+
|
| 204 |
+
@staticmethod
|
| 205 |
+
def match(node, problem_size: tuple):
|
| 206 |
+
if not node.is_output:
|
| 207 |
+
return False
|
| 208 |
+
if node.name in StoreImplBase.reserved_names:
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
strideMN = node.store_tensor.stride[-2:]
|
| 212 |
+
if strideMN == (0, 1):
|
| 213 |
+
return True
|
| 214 |
+
else:
|
| 215 |
+
return False
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class ScalarReductionImpl(ReductionImplBase):
|
| 219 |
+
|
| 220 |
+
@staticmethod
|
| 221 |
+
def match(node, problem_size: tuple):
|
| 222 |
+
if not node.is_output:
|
| 223 |
+
return False
|
| 224 |
+
if node.name in StoreImplBase.reserved_names:
|
| 225 |
+
return False
|
| 226 |
+
|
| 227 |
+
strideMN = node.store_tensor.stride[-2:]
|
| 228 |
+
if strideMN == (0, 0):
|
| 229 |
+
return True
|
| 230 |
+
else:
|
| 231 |
+
return False
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class StoreNode(NodeBase):
|
| 235 |
+
"""
|
| 236 |
+
Store node
|
| 237 |
+
"""
|
| 238 |
+
possible_impls = [
|
| 239 |
+
AuxStoreImpl, RowReductionImpl,
|
| 240 |
+
ColumnReductionImpl, ScalarReductionImpl,
|
| 241 |
+
NoOpImpl, StoreDImpl
|
| 242 |
+
]
|
| 243 |
+
def __init__(self, name: str) -> None:
|
| 244 |
+
super().__init__(name)
|
| 245 |
+
self.op = "store"
|
| 246 |
+
self.is_output = False
|
| 247 |
+
self._store_tensor = None
|
| 248 |
+
|
| 249 |
+
@property
|
| 250 |
+
def store_tensor(self) -> Tensor:
|
| 251 |
+
"""
|
| 252 |
+
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
| 253 |
+
"""
|
| 254 |
+
return self._store_tensor
|
| 255 |
+
|
| 256 |
+
@store_tensor.setter
|
| 257 |
+
def store_tensor(self, kwargs):
|
| 258 |
+
"""
|
| 259 |
+
Setting the tensor
|
| 260 |
+
"""
|
| 261 |
+
self._store_tensor = Tensor(**kwargs)
|
| 262 |
+
|
| 263 |
+
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 264 |
+
"""
|
| 265 |
+
The store nodes has element_output = element_input
|
| 266 |
+
"""
|
| 267 |
+
if self.is_output:
|
| 268 |
+
if self.store_tensor is None:
|
| 269 |
+
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
|
| 270 |
+
self.element = self.store_tensor.element
|
| 271 |
+
assert len(input_node_metas) == 1, "Store node can only have one input node"
|
| 272 |
+
self.element_output = input_node_metas[0].element_output
|
| 273 |
+
|
| 274 |
+
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 275 |
+
super().broadcast_propagation(input_node_metas)
|
| 276 |
+
if self.is_output:
|
| 277 |
+
self._store_tensor.broadcast(self.tensor.shape)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
High-level class for tensor
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_library import LayoutType
|
| 38 |
+
|
| 39 |
+
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
|
| 40 |
+
Layout,
|
| 41 |
+
broadcast,
|
| 42 |
+
canonicalization,
|
| 43 |
+
permutation,
|
| 44 |
+
reshape,
|
| 45 |
+
_reverse_tuple
|
| 46 |
+
)
|
| 47 |
+
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Tensor:
|
| 51 |
+
"""
|
| 52 |
+
The tensor abstracts the data type
|
| 53 |
+
"""
|
| 54 |
+
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
|
| 55 |
+
if element is not None and tensor is not None:
|
| 56 |
+
raise Exception(f"Must not specify both element and tensor")
|
| 57 |
+
elif shape is not None and tensor is not None:
|
| 58 |
+
raise Exception(f"Must not specify both shape and tensor")
|
| 59 |
+
elif layout_tag is not None and tensor is not None:
|
| 60 |
+
raise Exception(f"Must not specify both layout_tag and tensor")
|
| 61 |
+
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
|
| 62 |
+
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
|
| 63 |
+
elif stride is not None and tensor is not None:
|
| 64 |
+
raise Exception(f"Must not specify both stride and tensor")
|
| 65 |
+
elif stride is not None and layout_tag is not None:
|
| 66 |
+
raise Exception(f"Must not specify layout_tag when stride is provided")
|
| 67 |
+
|
| 68 |
+
if isinstance(tensor, Tensor):
|
| 69 |
+
# Directly copy all the attributes
|
| 70 |
+
self.__dict__.update(vars(tensor))
|
| 71 |
+
else:
|
| 72 |
+
if tensor is None:
|
| 73 |
+
self.element = library_type(element)
|
| 74 |
+
else:
|
| 75 |
+
self.element, layout_tag = get_datatype_and_layout(tensor)
|
| 76 |
+
shape = get_tensor_shape(tensor)
|
| 77 |
+
if stride is not None:
|
| 78 |
+
self.layout = Layout(shape[::-1], stride[::-1])
|
| 79 |
+
else:
|
| 80 |
+
if layout_tag == LayoutType.RowMajor:
|
| 81 |
+
self.layout = Layout(shape[::-1])
|
| 82 |
+
elif layout_tag == LayoutType.ColumnMajor:
|
| 83 |
+
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
|
| 84 |
+
self.layout = canonicalization(self.layout)
|
| 85 |
+
|
| 86 |
+
self.is_constant = is_constant
|
| 87 |
+
# Save the tensor value if it is constant
|
| 88 |
+
if is_constant and tensor is not None:
|
| 89 |
+
self.value = tensor
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def shape(self):
|
| 93 |
+
"""
|
| 94 |
+
Returns the RowMajor layout shape
|
| 95 |
+
"""
|
| 96 |
+
return _reverse_tuple(self.layout.shape)
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def stride(self):
|
| 100 |
+
"""
|
| 101 |
+
Returns the RowMajor layout stride
|
| 102 |
+
"""
|
| 103 |
+
return _reverse_tuple(self.layout.stride)
|
| 104 |
+
|
| 105 |
+
@property
|
| 106 |
+
def rank(self):
|
| 107 |
+
"""
|
| 108 |
+
Returns the rank of the tensor
|
| 109 |
+
"""
|
| 110 |
+
return len(self.shape)
|
| 111 |
+
|
| 112 |
+
#
|
| 113 |
+
# Layout Algorithms
|
| 114 |
+
#
|
| 115 |
+
|
| 116 |
+
def broadcast(self, shape):
|
| 117 |
+
"""
|
| 118 |
+
Broadcast self.layout to shape
|
| 119 |
+
"""
|
| 120 |
+
assert isinstance(shape, tuple)
|
| 121 |
+
self.layout = broadcast(self.layout, _reverse_tuple(shape))
|
| 122 |
+
|
| 123 |
+
def reshape(self, shape):
|
| 124 |
+
"""
|
| 125 |
+
Reshape self.layout to shape
|
| 126 |
+
"""
|
| 127 |
+
assert isinstance(shape, tuple)
|
| 128 |
+
reverse_shape = _reverse_tuple(shape)
|
| 129 |
+
self.layout = reshape(self.layout, reverse_shape)
|
| 130 |
+
|
| 131 |
+
def permute(self, indices):
|
| 132 |
+
"""
|
| 133 |
+
Permute self.layout according to indices
|
| 134 |
+
"""
|
| 135 |
+
length = len(indices)
|
| 136 |
+
indices = [length - idx - 1 for idx in indices]
|
| 137 |
+
self.layout = permutation(self.layout, indices[::-1])
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
|
| 34 |
+
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
|
| 35 |
+
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
| 36 |
+
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
| 37 |
+
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
| 38 |
+
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
| 39 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
|
| 40 |
+
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
| 41 |
+
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 42 |
+
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import subprocess
|
| 35 |
+
|
| 36 |
+
from cutlass_library import DataTypeTag
|
| 37 |
+
|
| 38 |
+
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
_COLOR_MAP = {
|
| 42 |
+
"load": '"AliceBlue"',
|
| 43 |
+
"compute": "LemonChiffon1",
|
| 44 |
+
"accumulator": "LightGrey",
|
| 45 |
+
"store": "PowderBlue",
|
| 46 |
+
"layout": "lightseagreen",
|
| 47 |
+
"dag": "darkorange"
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class EVTGraphDrawer:
|
| 52 |
+
"""
|
| 53 |
+
Visualize a EVT DAGIR with graphviz
|
| 54 |
+
"""
|
| 55 |
+
def __init__(
|
| 56 |
+
self,
|
| 57 |
+
graph: DAGIR,
|
| 58 |
+
name: str
|
| 59 |
+
):
|
| 60 |
+
self._name = name
|
| 61 |
+
self._dot_graphs = {}
|
| 62 |
+
|
| 63 |
+
self._dot_graphs[name] = self._to_dot(graph, name)
|
| 64 |
+
|
| 65 |
+
def _get_node_style(self, node):
|
| 66 |
+
template = {
|
| 67 |
+
"shape": "record",
|
| 68 |
+
"fillcolor": "#CAFFE3",
|
| 69 |
+
"style": '"filled,rounded"',
|
| 70 |
+
"fontcolor": "#000000",
|
| 71 |
+
}
|
| 72 |
+
if node.op in _COLOR_MAP:
|
| 73 |
+
template["fillcolor"] = _COLOR_MAP[node.op]
|
| 74 |
+
else:
|
| 75 |
+
raise NotImplementedError("unknown node op")
|
| 76 |
+
if node.disabled:
|
| 77 |
+
template["fontcolor"] = "grey"
|
| 78 |
+
template["fillcolor"] = "white"
|
| 79 |
+
return template
|
| 80 |
+
|
| 81 |
+
def _get_node_label(self, node):
|
| 82 |
+
label = "{" + f"name={node.name}|op={node.op}"
|
| 83 |
+
if node.op == "layout":
|
| 84 |
+
label += f"|fn={node.fn.__name__}"
|
| 85 |
+
for key in node.kwargs:
|
| 86 |
+
label += f"|{key}={node.kwargs[key]}"
|
| 87 |
+
if node.underlying_impl is not None:
|
| 88 |
+
label += f"|impl={type(node.underlying_impl).__name__}"
|
| 89 |
+
if node.op == "load":
|
| 90 |
+
label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
|
| 91 |
+
elif node.op == "compute":
|
| 92 |
+
label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
| 93 |
+
elif node.op == "store":
|
| 94 |
+
label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
| 95 |
+
elif node.op == "dag":
|
| 96 |
+
label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
| 97 |
+
if node.tensor is not None:
|
| 98 |
+
shape = node.tensor.shape
|
| 99 |
+
stride = node.tensor.stride
|
| 100 |
+
label += f"|shape={shape}|stride={stride}"
|
| 101 |
+
|
| 102 |
+
if hasattr(node, "store_tensor"):
|
| 103 |
+
if node.store_tensor is not None:
|
| 104 |
+
store_shape = node.store_tensor.shape
|
| 105 |
+
store_stride = node.store_tensor.stride
|
| 106 |
+
label += f"|store_shape={store_shape}|stride_stride={store_stride}"
|
| 107 |
+
|
| 108 |
+
label += "}"
|
| 109 |
+
return label
|
| 110 |
+
|
| 111 |
+
def _to_dot(
|
| 112 |
+
self,
|
| 113 |
+
graph: DAGIR,
|
| 114 |
+
name: str
|
| 115 |
+
):
|
| 116 |
+
import pydot
|
| 117 |
+
dot_graph = pydot.Dot(name, randir="TB")
|
| 118 |
+
for node in graph.nodes_meta:
|
| 119 |
+
style = self._get_node_style(node)
|
| 120 |
+
label = self._get_node_label(node)
|
| 121 |
+
dot_node = pydot.Node(
|
| 122 |
+
node.name, label=label, **style
|
| 123 |
+
)
|
| 124 |
+
dot_graph.add_node(dot_node)
|
| 125 |
+
if node.op == "dag":
|
| 126 |
+
dot_subgraph = self._to_dot(node.subgraph, name=node.name)
|
| 127 |
+
self._dot_graphs[node.name] = dot_subgraph
|
| 128 |
+
|
| 129 |
+
# Add edges
|
| 130 |
+
for src, dst in graph.edges:
|
| 131 |
+
weight = graph.get_edge_weight(src, dst)
|
| 132 |
+
dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
|
| 133 |
+
|
| 134 |
+
return dot_graph
|
| 135 |
+
|
| 136 |
+
def get_dot_graph(self) -> pydot.Dot:
|
| 137 |
+
return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]
|
| 138 |
+
|
| 139 |
+
def get_dot_graph_by_name(self, name) -> pydot.Dot:
|
| 140 |
+
return self._dot_graphs[name]
|
| 141 |
+
|
| 142 |
+
def get_main_dot_graph(self) -> pydot.Dot:
|
| 143 |
+
return self._dot_graphs[self._name]
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Construct the epilogue visitor argument type
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.backend.c_types import visitor_factory
|
| 38 |
+
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
|
| 39 |
+
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
| 40 |
+
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
| 41 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 42 |
+
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 43 |
+
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PassGetArgumentType(EVTPassBase):
|
| 47 |
+
"""
|
| 48 |
+
Construct the epilogue visitor argument type
|
| 49 |
+
"""
|
| 50 |
+
dependencies = [
|
| 51 |
+
PassShapeTypePropagation, # The Layout of all nodes must be set
|
| 52 |
+
PassDAG2Tree, # The type of each node must be set
|
| 53 |
+
PassGetImpl # The DAG subgraphs must be set
|
| 54 |
+
]
|
| 55 |
+
|
| 56 |
+
def requires(self) -> None:
|
| 57 |
+
# Check "D" is in the node list
|
| 58 |
+
if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
|
| 59 |
+
raise SyntaxError(
|
| 60 |
+
"Sm90+ EVT requires the epilogue to have a returned tensor D, "
|
| 61 |
+
"but the variable 'D' is not found in the return values.")
|
| 62 |
+
|
| 63 |
+
def call(self):
|
| 64 |
+
nodes = self.dag_ir.nodes_topological_order()
|
| 65 |
+
self.argument_types = {}
|
| 66 |
+
for node in nodes:
|
| 67 |
+
meta = self.dag_ir.get_node_meta(node)
|
| 68 |
+
if not meta.disabled:
|
| 69 |
+
self.argument_types[node] = meta.underlying_impl.argument_type
|
| 70 |
+
if node == "D" and cc_map[self.cc] in [90, 100]:
|
| 71 |
+
continue
|
| 72 |
+
if isinstance(meta, TopoVisitorNode):
|
| 73 |
+
self.get_dag_argument_type(node)
|
| 74 |
+
else:
|
| 75 |
+
self.get_evt_argument_type(node)
|
| 76 |
+
|
| 77 |
+
self.cc_specific_method(self.set_argument_type)()
|
| 78 |
+
|
| 79 |
+
def get_evt_argument_type(self, node):
|
| 80 |
+
# Sort the input nodes by edge weight
|
| 81 |
+
input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
|
| 82 |
+
if len(input_types) > 0:
|
| 83 |
+
self.argument_types[node] = visitor_factory(
|
| 84 |
+
input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])
|
| 85 |
+
|
| 86 |
+
def get_dag_argument_type(self, node):
|
| 87 |
+
meta = self.dag_ir.get_node_meta(node)
|
| 88 |
+
subgraph = meta.subgraph
|
| 89 |
+
subgraph_nodes = subgraph.nodes_topological_order()
|
| 90 |
+
# Visit the unvisited nodes in subgraph
|
| 91 |
+
for n in subgraph_nodes:
|
| 92 |
+
m = subgraph.get_node_meta(n)
|
| 93 |
+
if m.disabled:
|
| 94 |
+
continue
|
| 95 |
+
else:
|
| 96 |
+
self.argument_types[n] = m.underlying_impl.argument_type
|
| 97 |
+
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
|
| 98 |
+
if len(input_types) > 0:
|
| 99 |
+
self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])
|
| 100 |
+
|
| 101 |
+
def set_argument_type(self):
|
| 102 |
+
pass
|
| 103 |
+
|
| 104 |
+
def sm90_set_argument_type(self):
|
| 105 |
+
self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
|
| 106 |
+
# Get the tensorD argument type
|
| 107 |
+
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
|
| 108 |
+
|
| 109 |
+
# Get the tensorC argument type
|
| 110 |
+
if self.dag_ir.has_node("C"):
|
| 111 |
+
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
|
| 112 |
+
else:
|
| 113 |
+
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
|
| 114 |
+
|
| 115 |
+
def sm100_set_argument_type(self):
|
| 116 |
+
self.sm90_set_argument_type()
|
| 117 |
+
|
| 118 |
+
def sm80_set_argument_type(self):
|
| 119 |
+
nodes = self.dag_ir.nodes_topological_order()
|
| 120 |
+
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
|
| 35 |
+
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
from copy import deepcopy
|
| 39 |
+
|
| 40 |
+
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
|
| 41 |
+
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
| 42 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 43 |
+
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PassDAG2Tree(EVTPassBase):
|
| 47 |
+
"""
|
| 48 |
+
Convert the DAG IR to Tree by fusing subgraphs
|
| 49 |
+
"""
|
| 50 |
+
dependencies = [
|
| 51 |
+
PassShapeTypePropagation,
|
| 52 |
+
PassGetImpl
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
def call(self):
|
| 56 |
+
# Step 1: find the nodes that have multiple parents
|
| 57 |
+
multi_parent_nodes = []
|
| 58 |
+
|
| 59 |
+
for node in self.dag_ir.nodes_topological_order():
|
| 60 |
+
if self.dag_ir.out_degree(node) > 1:
|
| 61 |
+
multi_parent_nodes.append(node)
|
| 62 |
+
# Step 2: find the lowest common ancestor (LCA) of all its parents
|
| 63 |
+
for node in multi_parent_nodes:
|
| 64 |
+
# A multi-parent node could be already fused by the previous node
|
| 65 |
+
if not self.dag_ir.has_node(node):
|
| 66 |
+
continue
|
| 67 |
+
# A node uncovered by the previous fusions can have out degree change
|
| 68 |
+
# Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
|
| 69 |
+
# Case 2: it has more than one edges to the previously fused subgraph, degree drops
|
| 70 |
+
if self.dag_ir.out_degree(node) <= 1:
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
# Otherwise, the node still
|
| 74 |
+
reachable_nodes = []
|
| 75 |
+
# Complexity: O(Dout*N)
|
| 76 |
+
for parent in self.dag_ir.get_users(node):
|
| 77 |
+
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
|
| 78 |
+
# get the common reachable objects
|
| 79 |
+
common_items = set.intersection(*reachable_nodes)
|
| 80 |
+
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
|
| 81 |
+
|
| 82 |
+
lca = None
|
| 83 |
+
# If common ancestor exists, find the lowest one
|
| 84 |
+
if len(common_items) > 0:
|
| 85 |
+
topo_order = self.dag_ir.nodes_topological_order()
|
| 86 |
+
topo_idx = -1
|
| 87 |
+
for item in common_items:
|
| 88 |
+
if lca is None:
|
| 89 |
+
lca = item
|
| 90 |
+
topo_idx = topo_order.index(item)
|
| 91 |
+
else:
|
| 92 |
+
if topo_idx > topo_order.index(item):
|
| 93 |
+
lca = item
|
| 94 |
+
topo_idx = topo_order.index(item)
|
| 95 |
+
else:
|
| 96 |
+
# there is no common ancestor for all the parents, we pack all the reachable
|
| 97 |
+
# nodes into a single DAG node as a fallback. The lca should be the input node of
|
| 98 |
+
# one of the output nodes with out_degree = 0
|
| 99 |
+
potential_output_nodes = []
|
| 100 |
+
for node in node_to_fuse:
|
| 101 |
+
if self.dag_ir.out_degree(node) == 0:
|
| 102 |
+
potential_output_nodes.append(node)
|
| 103 |
+
if len(potential_output_nodes) == 0:
|
| 104 |
+
raise RuntimeError(f"No output node with out degree = 0 found.")
|
| 105 |
+
|
| 106 |
+
output_node = None
|
| 107 |
+
if (self.dag_ir.cc >= 90):
|
| 108 |
+
# For SM90+, the lca should be the input node of D
|
| 109 |
+
if (not self.dag_ir.has_node("D")):
|
| 110 |
+
raise RuntimeError(f"D is not a node in the DAG IR.")
|
| 111 |
+
output_node = "D"
|
| 112 |
+
else:
|
| 113 |
+
output_node = potential_output_nodes[0]
|
| 114 |
+
|
| 115 |
+
if (output_node is None):
|
| 116 |
+
raise RuntimeError(f"No output node found.")
|
| 117 |
+
lca = self.dag_ir.get_all_inputs(output_node)[0]
|
| 118 |
+
node_to_fuse.remove(output_node)
|
| 119 |
+
|
| 120 |
+
# The lca is the output node of the DAG node
|
| 121 |
+
# Get the nodes to be fused
|
| 122 |
+
node_to_fuse.add(lca)
|
| 123 |
+
# Get all the input nodes
|
| 124 |
+
all_input_nodes = []
|
| 125 |
+
all_output_nodes = []
|
| 126 |
+
for node in node_to_fuse:
|
| 127 |
+
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
|
| 128 |
+
all_output_nodes.append(set(self.dag_ir.get_users(node)))
|
| 129 |
+
all_input_nodes = set.union(*all_input_nodes)
|
| 130 |
+
all_output_nodes = set.union(*all_output_nodes)
|
| 131 |
+
|
| 132 |
+
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
|
| 133 |
+
|
| 134 |
+
# Create the subgraph
|
| 135 |
+
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
| 136 |
+
subgraph = DAGIR(self.dag_ir.cc)
|
| 137 |
+
for node in subgraph_.nodes:
|
| 138 |
+
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
| 139 |
+
if node not in node_to_fuse:
|
| 140 |
+
meta.disabled = True
|
| 141 |
+
subgraph.add_node(meta)
|
| 142 |
+
for edge in subgraph_.edges:
|
| 143 |
+
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Create the fused node
|
| 147 |
+
dag_node = TopoVisitorNode(
|
| 148 |
+
name=f"dag_{lca}", subgraph=subgraph,
|
| 149 |
+
output_node=self.dag_ir.get_node_meta(lca))
|
| 150 |
+
self.dag_ir.add_node(dag_node)
|
| 151 |
+
|
| 152 |
+
# Add input edges
|
| 153 |
+
for idx, node in enumerate(all_input_nodes):
|
| 154 |
+
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
|
| 155 |
+
|
| 156 |
+
# Replace all uses with DAG node (only 1 output node)
|
| 157 |
+
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
|
| 158 |
+
|
| 159 |
+
# Remove all fused nodes
|
| 160 |
+
node_to_fuse.remove(lca)
|
| 161 |
+
for node in node_to_fuse:
|
| 162 |
+
self.dag_ir.remove_node(node)
|
| 163 |
+
|
| 164 |
+
def ensures(self) -> None:
|
| 165 |
+
# Ensure that after the pass, the resulting DAG becomes a tree
|
| 166 |
+
for node in self.dag_ir.nodes:
|
| 167 |
+
out_degree = self.dag_ir.out_degree(node)
|
| 168 |
+
if out_degree > 1:
|
| 169 |
+
raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Fix the element_output of producer of D.
|
| 35 |
+
|
| 36 |
+
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
| 37 |
+
element converter, so the compute node producing D must have element_output = type(D).
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
| 41 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class PassFixElementD(EVTPassBase):
|
| 45 |
+
"""
|
| 46 |
+
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
| 47 |
+
element converter, so the compute node producing D must have
|
| 48 |
+
element_output = type(D)
|
| 49 |
+
"""
|
| 50 |
+
dependencies = [
|
| 51 |
+
PassLayoutManipulateElimination
|
| 52 |
+
]
|
| 53 |
+
def get_producer(self, node, element_D):
|
| 54 |
+
node_meta = self.dag_ir.get_node_meta(node)
|
| 55 |
+
if node_meta.op == "compute":
|
| 56 |
+
node_meta.element_output = element_D
|
| 57 |
+
elif node_meta.op == "store":
|
| 58 |
+
self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)
|
| 59 |
+
|
| 60 |
+
def call(self):
|
| 61 |
+
if self.dag_ir.has_node("D"):
|
| 62 |
+
node_d_meta = self.dag_ir.get_node_meta("D")
|
| 63 |
+
element_D = node_d_meta.store_tensor.element
|
| 64 |
+
self.get_producer("D", element_D)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Infer the underlying implement of each node.
|
| 35 |
+
|
| 36 |
+
While the frontend only distinguish between Load/Store/Compute Node,
|
| 37 |
+
each of these nodes can have different underlying implementation based
|
| 38 |
+
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
| 39 |
+
This pass infers the underlying impl of each node
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
import cutlass_cppgen.backend.evt.backend as evt_backend
|
| 43 |
+
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
|
| 44 |
+
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
| 45 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 46 |
+
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
|
| 47 |
+
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 48 |
+
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class PassGetImpl(EVTPassBase):
|
| 52 |
+
"""
|
| 53 |
+
While the frontend only distinguish between Load/Store/Compute Node,
|
| 54 |
+
each of these nodes can have different underlying implementation based
|
| 55 |
+
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
| 56 |
+
This pass infers the underlying impl of each node
|
| 57 |
+
"""
|
| 58 |
+
dependencies = [
|
| 59 |
+
PassShapeTypePropagation, # The shape and type info are required for inference
|
| 60 |
+
PassFixElementD
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
def __init__(self, dag_ir: DAGIR) -> None:
|
| 64 |
+
super().__init__(dag_ir)
|
| 65 |
+
self.no_op_elimination = PassNoOpElimination(dag_ir)
|
| 66 |
+
|
| 67 |
+
def requires(self) -> None:
|
| 68 |
+
# Verify "accum" is in the arg list
|
| 69 |
+
if not self.dag_ir.has_node("accum"):
|
| 70 |
+
raise SyntaxError("Cannot find 'accum' in the argument list.")
|
| 71 |
+
|
| 72 |
+
def call(self):
|
| 73 |
+
# The loop structure of the epilogue is determined by the
|
| 74 |
+
# accumulator shape
|
| 75 |
+
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
|
| 76 |
+
problem_size = accumulator.tensor.shape
|
| 77 |
+
|
| 78 |
+
for node_meta in self.dag_ir.node_metas_topological_order():
|
| 79 |
+
node_meta.get_underlying_impl(problem_size)
|
| 80 |
+
|
| 81 |
+
def ensures(self) -> None:
|
| 82 |
+
# Some nodes will be lowered to NoOp, eliminate them
|
| 83 |
+
self.no_op_elimination()
|
| 84 |
+
# Lower to cc-specific impl
|
| 85 |
+
for node_meta in self.dag_ir.nodes_meta:
|
| 86 |
+
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
|
| 87 |
+
node_meta.underlying_impl = getattr(
|
| 88 |
+
node_impl_ccs,
|
| 89 |
+
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
|
| 90 |
+
)(node_meta)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Eliminate layout manipulation nodes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from copy import deepcopy
|
| 38 |
+
|
| 39 |
+
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
|
| 40 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 41 |
+
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class PassLayoutManipulateElimination(EVTPassBase):
|
| 45 |
+
"""
|
| 46 |
+
Eliminate layout manipulation nodes
|
| 47 |
+
"""
|
| 48 |
+
dependencies = [PassShapeTypePropagation]
|
| 49 |
+
|
| 50 |
+
def __init__(self, dag_ir: DAGIR) -> None:
|
| 51 |
+
super().__init__(dag_ir)
|
| 52 |
+
self.copy_cnt = 0
|
| 53 |
+
|
| 54 |
+
def call(self):
|
| 55 |
+
self.layout_nodes_worklist = self.get_all_layout_nodes()
|
| 56 |
+
# Run while loop utill all layout nodes are eliminated
|
| 57 |
+
while(len(self.layout_nodes_worklist) > 0):
|
| 58 |
+
node = self.layout_nodes_worklist.pop(0)
|
| 59 |
+
# for node in layout_nodes:
|
| 60 |
+
# Step 1: get the propagation direction
|
| 61 |
+
direction = self.get_propagation_direction(node)
|
| 62 |
+
self.visited = []
|
| 63 |
+
getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
|
| 64 |
+
# Eliminate the current node
|
| 65 |
+
input_node = self.dag_ir.get_all_inputs(node)[0]
|
| 66 |
+
self.dag_ir.replace_all_uses_with(node, input_node)
|
| 67 |
+
# layout_nodes = self.get_all_layout_nodes()
|
| 68 |
+
|
| 69 |
+
def get_all_layout_nodes(self):
|
| 70 |
+
layout_nodes = []
|
| 71 |
+
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
|
| 72 |
+
if isinstance(node_meta, LayoutNode):
|
| 73 |
+
layout_nodes.append(node_meta.name)
|
| 74 |
+
return layout_nodes
|
| 75 |
+
|
| 76 |
+
def get_propagation_direction(self, node: str):
|
| 77 |
+
"""
|
| 78 |
+
The logic is propagating all layout nodes away from the accumulator node.
|
| 79 |
+
"""
|
| 80 |
+
self.visited = []
|
| 81 |
+
self.get_influenced_users(node)
|
| 82 |
+
nodes_influenced_dir_users = self.visited
|
| 83 |
+
self.visited = []
|
| 84 |
+
self.get_influenced_inputs(node)
|
| 85 |
+
nodes_influenced_dir_inputs = self.visited
|
| 86 |
+
|
| 87 |
+
if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
|
| 88 |
+
return "inputs"
|
| 89 |
+
elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
|
| 90 |
+
return "users"
|
| 91 |
+
else:
|
| 92 |
+
raise RuntimeError("Unsolved propagation direction")
|
| 93 |
+
|
| 94 |
+
# Get all influenced nodes if we propagate along the user direction
|
| 95 |
+
def get_influenced_users(self, node: str):
|
| 96 |
+
if node in self.visited:
|
| 97 |
+
return
|
| 98 |
+
self.visited.append(node)
|
| 99 |
+
|
| 100 |
+
users = self.dag_ir.get_users(node)
|
| 101 |
+
for user in users:
|
| 102 |
+
self.get_influenced_users(user)
|
| 103 |
+
user_inputs = []
|
| 104 |
+
for user in users:
|
| 105 |
+
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
|
| 106 |
+
if len(user_inputs) > 0:
|
| 107 |
+
user_inputs = set.union(*user_inputs)
|
| 108 |
+
user_inputs.remove(node)
|
| 109 |
+
for input in user_inputs:
|
| 110 |
+
self.get_influenced_inputs(input)
|
| 111 |
+
|
| 112 |
+
# Get all influenced nodes if we propagate along the input direction
|
| 113 |
+
def get_influenced_inputs(self, node: str):
|
| 114 |
+
if node in self.visited:
|
| 115 |
+
return
|
| 116 |
+
self.visited.append(node)
|
| 117 |
+
|
| 118 |
+
inputs = self.dag_ir.get_all_inputs(node)
|
| 119 |
+
for input in inputs:
|
| 120 |
+
self.get_influenced_inputs(input)
|
| 121 |
+
input_users = []
|
| 122 |
+
for input in inputs:
|
| 123 |
+
input_users.append(set(self.dag_ir.get_users(input)))
|
| 124 |
+
if len(input_users) > 0:
|
| 125 |
+
input_users = set.union(*input_users)
|
| 126 |
+
input_users.remove(node)
|
| 127 |
+
for user in input_users:
|
| 128 |
+
self.get_influenced_users(user)
|
| 129 |
+
|
| 130 |
+
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
|
| 131 |
+
copied_node_meta = deepcopy(layout_node_meta)
|
| 132 |
+
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
| 133 |
+
self.copy_cnt += 1
|
| 134 |
+
copied_node_meta.name = copied_node
|
| 135 |
+
self.dag_ir.add_node(copied_node_meta)
|
| 136 |
+
# Add edges
|
| 137 |
+
target_inputs = self.dag_ir.get_all_inputs(target)
|
| 138 |
+
for src in target_inputs:
|
| 139 |
+
self.dag_ir.remove_edge(src, target)
|
| 140 |
+
self.dag_ir.add_edge(src, copied_node)
|
| 141 |
+
self.dag_ir.add_edge(copied_node, target)
|
| 142 |
+
self.layout_nodes_worklist.append(copied_node)
|
| 143 |
+
|
| 144 |
+
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
|
| 145 |
+
copied_node_meta = deepcopy(layout_node_meta)
|
| 146 |
+
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
| 147 |
+
self.copy_cnt += 1
|
| 148 |
+
copied_node_meta.name = copied_node
|
| 149 |
+
self.dag_ir.add_node(copied_node_meta)
|
| 150 |
+
# Add edges
|
| 151 |
+
users = self.dag_ir.get_users(target)
|
| 152 |
+
for user in users:
|
| 153 |
+
self.dag_ir.remove_edge(target, user)
|
| 154 |
+
self.dag_ir.add_edge(copied_node, user)
|
| 155 |
+
self.dag_ir.add_edge(target, copied_node)
|
| 156 |
+
self.layout_nodes_worklist.append(copied_node)
|
| 157 |
+
|
| 158 |
+
# Propagate the layout `node` along the user direction
|
| 159 |
+
def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
|
| 160 |
+
"""
|
| 161 |
+
Propagate layout node to users
|
| 162 |
+
"""
|
| 163 |
+
if node in self.visited:
|
| 164 |
+
# Avoid applying twice
|
| 165 |
+
return
|
| 166 |
+
self.visited.append(node)
|
| 167 |
+
|
| 168 |
+
node_meta = self.dag_ir.get_node_meta(node)
|
| 169 |
+
if layout_node_meta.name != node:
|
| 170 |
+
if isinstance(node_meta, LayoutNode):
|
| 171 |
+
# Layout node is not transparent with layout node
|
| 172 |
+
self.add_copy_before(layout_node_meta, node)
|
| 173 |
+
return
|
| 174 |
+
else:
|
| 175 |
+
layout_node_meta.apply_to_user(node_meta)
|
| 176 |
+
|
| 177 |
+
users = self.dag_ir.get_users(node)
|
| 178 |
+
user_inputs = []
|
| 179 |
+
for user in users:
|
| 180 |
+
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
|
| 181 |
+
for user in users:
|
| 182 |
+
self.propagate_to_users(layout_node_meta, user)
|
| 183 |
+
if len(user_inputs) > 0:
|
| 184 |
+
user_inputs = set.union(*user_inputs)
|
| 185 |
+
user_inputs.remove(node)
|
| 186 |
+
for input in user_inputs:
|
| 187 |
+
self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
|
| 188 |
+
|
| 189 |
+
# Propagate the layout `node` along the input direction
|
| 190 |
+
def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
|
| 191 |
+
"""
|
| 192 |
+
Propagate layout node to inputs
|
| 193 |
+
"""
|
| 194 |
+
if node in self.visited:
|
| 195 |
+
# Avoid applying twice
|
| 196 |
+
return
|
| 197 |
+
self.visited.append(node)
|
| 198 |
+
|
| 199 |
+
node_meta = self.dag_ir.get_node_meta(node)
|
| 200 |
+
if layout_node_meta.name != node:
|
| 201 |
+
if isinstance(node_meta, LayoutNode):
|
| 202 |
+
# Layout node is not transparent with layout node
|
| 203 |
+
self.add_copy_after(layout_node_meta, node)
|
| 204 |
+
return
|
| 205 |
+
else:
|
| 206 |
+
layout_node_meta.apply_to_input(node_meta)
|
| 207 |
+
inputs = self.dag_ir.get_all_inputs(node)
|
| 208 |
+
input_users = []
|
| 209 |
+
for input in inputs:
|
| 210 |
+
input_users.append(set(self.dag_ir.get_users(input)))
|
| 211 |
+
for input in inputs:
|
| 212 |
+
self.propagate_to_inputs(layout_node_meta, input)
|
| 213 |
+
if len(input_users) > 0:
|
| 214 |
+
input_users = set.union(*input_users)
|
| 215 |
+
input_users.remove(node)
|
| 216 |
+
for user in input_users:
|
| 217 |
+
self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Pass manager for DAG IR.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from typing import Any
|
| 38 |
+
|
| 39 |
+
import networkx as nx
|
| 40 |
+
|
| 41 |
+
from cutlass_cppgen.backend.evt.ir import DAGIR
|
| 42 |
+
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class EVTPassBase:
|
| 46 |
+
"""
|
| 47 |
+
Base class for EVT Passes
|
| 48 |
+
"""
|
| 49 |
+
dependencies = []
|
| 50 |
+
def __init__(self, dag_ir: DAGIR) -> None:
|
| 51 |
+
self.dag_ir = dag_ir
|
| 52 |
+
self.cc = self.dag_ir.cc
|
| 53 |
+
|
| 54 |
+
def requires(self) -> None:
|
| 55 |
+
"""
|
| 56 |
+
This function will be called before the pass is run.
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
def call(self) -> None:
|
| 61 |
+
"""
|
| 62 |
+
The pass that is run through the self.dag_ir
|
| 63 |
+
"""
|
| 64 |
+
raise NotImplementedError(
|
| 65 |
+
f"__call__ is not overwritten in Pass {self.__class__.__name__}")
|
| 66 |
+
|
| 67 |
+
def ensures(self) -> None:
|
| 68 |
+
"""
|
| 69 |
+
This function will be called after the pass is run.
|
| 70 |
+
"""
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
def __call__(self) -> Any:
|
| 74 |
+
self.requires()
|
| 75 |
+
self.call()
|
| 76 |
+
self.ensures()
|
| 77 |
+
|
| 78 |
+
def cc_specific_method(self, func):
|
| 79 |
+
"""
|
| 80 |
+
This enables defining function that behaves differently under different cc
|
| 81 |
+
The simplest example of using this function is the following
|
| 82 |
+
|
| 83 |
+
.. highlight:: python
|
| 84 |
+
.. code-block:: python
|
| 85 |
+
|
| 86 |
+
class ExamplePass(EVTPassBase):
|
| 87 |
+
|
| 88 |
+
def call(sekf):
|
| 89 |
+
# This automatically select the smXX_func based on current cc
|
| 90 |
+
self.cc_specific_method(self.func)()
|
| 91 |
+
|
| 92 |
+
# Interface func, can be empty
|
| 93 |
+
def func(self):
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
# Sm90 specific func
|
| 97 |
+
def sm90_func(self):
|
| 98 |
+
// sm90 specific method
|
| 99 |
+
return
|
| 100 |
+
|
| 101 |
+
# Sm80 specific func
|
| 102 |
+
def sm80_func(self):
|
| 103 |
+
// sm80 specific method
|
| 104 |
+
return
|
| 105 |
+
"""
|
| 106 |
+
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
|
| 107 |
+
if hasattr(self, func_name):
|
| 108 |
+
return getattr(self, func_name)
|
| 109 |
+
else:
|
| 110 |
+
raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class EVTPassManager(nx.DiGraph):
|
| 114 |
+
"""
|
| 115 |
+
Topological-based Pass Manager.
|
| 116 |
+
Each registered pass has a list of dependencies. The pass manager organizes
|
| 117 |
+
the passes as a DAG and launch the compiler passes under topological order.
|
| 118 |
+
"""
|
| 119 |
+
def __init__(self, dag_ir: DAGIR, pass_list):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.dag_ir = dag_ir
|
| 122 |
+
for pass_cls in pass_list:
|
| 123 |
+
self.add_pass(pass_cls)
|
| 124 |
+
|
| 125 |
+
self.sorted_passes = self.schedule()
|
| 126 |
+
|
| 127 |
+
def get_callable(self, pass_name):
|
| 128 |
+
"""
|
| 129 |
+
Return the callable of the pass
|
| 130 |
+
"""
|
| 131 |
+
return self.nodes[pass_name]["callable"]
|
| 132 |
+
|
| 133 |
+
def add_pass(self, pass_cls):
|
| 134 |
+
"""
|
| 135 |
+
Add a pass to the pass manager
|
| 136 |
+
:param pass_cls: the class of pass
|
| 137 |
+
:type pass_cls: derived class of EVTPassBase
|
| 138 |
+
"""
|
| 139 |
+
name = pass_cls.__name__
|
| 140 |
+
pass_callable = pass_cls(self.dag_ir)
|
| 141 |
+
self.add_node(name, callable=pass_callable)
|
| 142 |
+
|
| 143 |
+
def schedule(self):
|
| 144 |
+
"""
|
| 145 |
+
Schedule the added passes under topological order
|
| 146 |
+
"""
|
| 147 |
+
# Add edges
|
| 148 |
+
for pass_name in self.nodes:
|
| 149 |
+
callable = self.get_callable(pass_name)
|
| 150 |
+
for dependency_cls in callable.dependencies:
|
| 151 |
+
self.add_edge(
|
| 152 |
+
dependency_cls.__name__,
|
| 153 |
+
type(callable).__name__)
|
| 154 |
+
|
| 155 |
+
# Topological sort
|
| 156 |
+
return list(nx.topological_sort(self))
|
| 157 |
+
|
| 158 |
+
def __call__(self) -> Any:
|
| 159 |
+
"""
|
| 160 |
+
Launch the registered passes
|
| 161 |
+
"""
|
| 162 |
+
for pass_name in self.sorted_passes:
|
| 163 |
+
callable = self.get_callable(pass_name)
|
| 164 |
+
callable()
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
No op elimination node
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from typing import Any
|
| 38 |
+
|
| 39 |
+
from cutlass_cppgen.backend.evt.ir import NoOpImpl
|
| 40 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class PassNoOpElimination(EVTPassBase):
|
| 44 |
+
"""
|
| 45 |
+
The dead node elimination pass removes nodes with NoOpImpl in DAG IR
|
| 46 |
+
"""
|
| 47 |
+
dependencies = []
|
| 48 |
+
|
| 49 |
+
def call(self) -> Any:
|
| 50 |
+
for node in self.dag_ir.nodes_topological_order():
|
| 51 |
+
node_meta = self.dag_ir.get_node_meta(node)
|
| 52 |
+
if isinstance(node_meta.underlying_impl, NoOpImpl):
|
| 53 |
+
self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Preprocess the reduction nodes.
|
| 35 |
+
|
| 36 |
+
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
|
| 37 |
+
This pass fuses these into a single store node, and then replaces all uses of the
|
| 38 |
+
current node with the new store node.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
|
| 42 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PassPreprocessRed(EVTPassBase):
|
| 46 |
+
"""
|
| 47 |
+
Preprocess red nodes
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def call(self):
|
| 51 |
+
# Step 1: find the compute nodes with op=red
|
| 52 |
+
red_compute_nodes = []
|
| 53 |
+
for node_meta in self.dag_ir.nodes_meta:
|
| 54 |
+
if isinstance(node_meta, ComputeNode):
|
| 55 |
+
if type(node_meta.fn) == tuple:
|
| 56 |
+
# To keep the frontend simple, the reduction nodes
|
| 57 |
+
# are parsed into compute nodes by default
|
| 58 |
+
# The simple heuristic to distinguish between compute
|
| 59 |
+
# and reduction node is that compute node is a single function,
|
| 60 |
+
# while the reduction node is a tuple of functions for
|
| 61 |
+
# in-register reduction and atomic global memory reduction
|
| 62 |
+
red_compute_nodes.append(node_meta.name)
|
| 63 |
+
|
| 64 |
+
# Step 2: for each compute, merge it with the succeeding store
|
| 65 |
+
for node in red_compute_nodes:
|
| 66 |
+
# Verify
|
| 67 |
+
users = self.dag_ir.get_users(node)
|
| 68 |
+
inputs = self.dag_ir.get_all_inputs(node)
|
| 69 |
+
# Has a single user
|
| 70 |
+
assert len(users) == 1
|
| 71 |
+
assert len(inputs) == 1
|
| 72 |
+
user = users[0]
|
| 73 |
+
input = inputs[0]
|
| 74 |
+
|
| 75 |
+
user_meta = self.dag_ir.get_node_meta(user)
|
| 76 |
+
# Must be a store node
|
| 77 |
+
assert isinstance(user_meta, StoreNode)
|
| 78 |
+
# With output degree == 0
|
| 79 |
+
assert self.dag_ir.out_degree(user) == 0
|
| 80 |
+
# Register the reduce op
|
| 81 |
+
node_meta = self.dag_ir.get_node_meta(node)
|
| 82 |
+
user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
|
| 83 |
+
user_meta.element_compute = node_meta.element_compute
|
| 84 |
+
user_meta.round_style = node_meta.round_style
|
| 85 |
+
|
| 86 |
+
# Replace all uses
|
| 87 |
+
self.dag_ir.remove_edge(input, node)
|
| 88 |
+
input_users = self.dag_ir.get_users(input)
|
| 89 |
+
for iu in input_users:
|
| 90 |
+
weight = self.dag_ir.get_edge_weight(input, iu)
|
| 91 |
+
self.dag_ir.add_edge(user, iu, weight)
|
| 92 |
+
self.dag_ir.remove_edge(input, iu)
|
| 93 |
+
self.dag_ir.add_edge(input, user)
|
| 94 |
+
self.dag_ir.remove_node(node)
|
| 95 |
+
|
| 96 |
+
# Register the reduction name
|
| 97 |
+
self.dag_ir.reduction_names.append(user)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Shape and type propagation pass
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
| 38 |
+
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 39 |
+
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class PassShapeTypePropagation(EVTPassBase):
|
| 43 |
+
"""
|
| 44 |
+
Propagate the shape and type of all nodes
|
| 45 |
+
"""
|
| 46 |
+
dependencies = [PassPreprocessRed]
|
| 47 |
+
|
| 48 |
+
def call(self):
|
| 49 |
+
# Propagate the node shape and type
|
| 50 |
+
for node in self.dag_ir.nodes_topological_order():
|
| 51 |
+
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
|
| 52 |
+
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
|
| 53 |
+
node_meta.type_propagation(input_node_metas)
|
| 54 |
+
node_meta.shape_propagation(input_node_metas)
|
| 55 |
+
|
| 56 |
+
for node in reversed(self.dag_ir.nodes_topological_order()):
|
| 57 |
+
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
|
| 58 |
+
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
|
| 59 |
+
node_meta.broadcast_propagation(input_node_metas)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Compute the shared memory size in bytes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from math import gcd
|
| 38 |
+
|
| 39 |
+
import cutlass_library
|
| 40 |
+
from pycute import flatten, shape_div, product
|
| 41 |
+
|
| 42 |
+
import cutlass_cppgen
|
| 43 |
+
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
| 44 |
+
from cutlass_cppgen.backend.library import DataType, DataTypeSize
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class GetSmemSize:
|
| 48 |
+
"""
|
| 49 |
+
Get the size in byte of shared memory used by the kernel
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self, dag_ir: DAGIR) -> None:
|
| 52 |
+
self.dag_ir = dag_ir
|
| 53 |
+
self.cc = self.dag_ir.cc
|
| 54 |
+
|
| 55 |
+
#
|
| 56 |
+
# Sm90 epilogue specific
|
| 57 |
+
#
|
| 58 |
+
|
| 59 |
+
def sm90_epilogue_tile(self, tile_description):
|
| 60 |
+
# Get the epilogue tile size
|
| 61 |
+
schedule = tile_description.epilogue_schedule
|
| 62 |
+
if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
|
| 63 |
+
element_d = self.dag_ir.get_node_meta("D").element
|
| 64 |
+
nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
|
| 65 |
+
epi_tile_m = min(64, tile_description.threadblock_shape[0])
|
| 66 |
+
epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
|
| 67 |
+
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
|
| 68 |
+
elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
|
| 69 |
+
epi_tile_m = min(128, tile_description.threadblock_shape[0])
|
| 70 |
+
epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
|
| 71 |
+
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
|
| 72 |
+
else:
|
| 73 |
+
raise NotImplementedError(f"Unsupported schedule: {schedule}")
|
| 74 |
+
|
| 75 |
+
# Get the pipeline stages
|
| 76 |
+
stages_d = 2
|
| 77 |
+
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
|
| 78 |
+
if self.dag_ir.has_node("C"):
|
| 79 |
+
element_c = self.dag_ir.get_node_meta("C").element
|
| 80 |
+
else:
|
| 81 |
+
element_c = None
|
| 82 |
+
|
| 83 |
+
element_d = self.dag_ir.get_node_meta("D").element
|
| 84 |
+
if element_c == element_d:
|
| 85 |
+
reuse_smem_c = True
|
| 86 |
+
else:
|
| 87 |
+
reuse_smem_c = False
|
| 88 |
+
stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
|
| 89 |
+
|
| 90 |
+
# Record the epilogue tile
|
| 91 |
+
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
|
| 92 |
+
self.epilogue_tile_mn = epilogue_tile_mn
|
| 93 |
+
self.epi_tiles = epi_tiles
|
| 94 |
+
self.stages_c = stages_c
|
| 95 |
+
self.stages_d = stages_d
|
| 96 |
+
self.reuse_smem_c = reuse_smem_c
|
| 97 |
+
self.element_c = element_c
|
| 98 |
+
self.element_d = element_d
|
| 99 |
+
self.is_source_supported = element_c is not None
|
| 100 |
+
|
| 101 |
+
def sm90_or_sm100_epilogue_smem_size(self, tile_description):
|
| 102 |
+
# Get the Fusion Storage
|
| 103 |
+
nodes = self.dag_ir.nodes_topological_order()
|
| 104 |
+
self.smem_types = {}
|
| 105 |
+
for node in nodes:
|
| 106 |
+
meta = self.dag_ir.get_node_meta(node)
|
| 107 |
+
if not meta.disabled:
|
| 108 |
+
self.smem_types[node] = meta.underlying_impl.get_smem_size(
|
| 109 |
+
self.cta_tile_mnk, self.epilogue_tile_mn,
|
| 110 |
+
self.stages_c, self.stages_d, self.epi_tiles)
|
| 111 |
+
if node == "D":
|
| 112 |
+
continue
|
| 113 |
+
if isinstance(meta, TopoVisitorNode):
|
| 114 |
+
self.get_dag_smem_type(node)
|
| 115 |
+
else:
|
| 116 |
+
self.get_evt_smem_type(node)
|
| 117 |
+
|
| 118 |
+
thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
|
| 119 |
+
# Get the Tensor Storage
|
| 120 |
+
tensors = []
|
| 121 |
+
if self.is_source_supported:
|
| 122 |
+
smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
|
| 123 |
+
tensors.append((smem_C, 128))
|
| 124 |
+
else:
|
| 125 |
+
tensors.append((0, 1))
|
| 126 |
+
if self.reuse_smem_c:
|
| 127 |
+
tensors.append((0, 128))
|
| 128 |
+
else:
|
| 129 |
+
smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
|
| 130 |
+
tensors.append((smem_D, 128))
|
| 131 |
+
tensors.append((thread_smem_size, 128))
|
| 132 |
+
|
| 133 |
+
tensor_smem_size = self.get_struct_size(tensors)
|
| 134 |
+
# Get pipeline storage size
|
| 135 |
+
# sizeof(uint64_t * stages_c * 2), alignment of uint64_t
|
| 136 |
+
# 2 is for FullBarrier and EmptyBarrier
|
| 137 |
+
pipeline_smem_size = (8 * self.stages_c * 2, 8)
|
| 138 |
+
|
| 139 |
+
# get SharedStorage size
|
| 140 |
+
smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
|
| 141 |
+
return smem_size[0]
|
| 142 |
+
|
| 143 |
+
def sm90_epilogue_smem_size(self, tile_description):
|
| 144 |
+
"""
|
| 145 |
+
Compute the shared memory size of sm90 collective epilogue
|
| 146 |
+
"""
|
| 147 |
+
self.sm90_epilogue_tile(tile_description)
|
| 148 |
+
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
| 149 |
+
|
| 150 |
+
#
|
| 151 |
+
# Sm100 epilogue specific
|
| 152 |
+
#
|
| 153 |
+
|
| 154 |
+
def sm100_epilogue_tile(self, tile_description):
|
| 155 |
+
cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
|
| 156 |
+
mma_tile = cta_tile
|
| 157 |
+
|
| 158 |
+
if tile_description.is_2sm:
|
| 159 |
+
cta_tile = (cta_tile[0] // 2, cta_tile[1])
|
| 160 |
+
|
| 161 |
+
if tile_description.is_2sm and mma_tile[0] == 128:
|
| 162 |
+
tmem_warps = (2, 2)
|
| 163 |
+
else:
|
| 164 |
+
tmem_warps = (4, 1)
|
| 165 |
+
|
| 166 |
+
if self.dag_ir.has_node("C"):
|
| 167 |
+
element_c = self.dag_ir.get_node_meta("C").element
|
| 168 |
+
element_c_size = DataTypeSize[element_c]
|
| 169 |
+
else:
|
| 170 |
+
element_c = None
|
| 171 |
+
element_c_size = 0
|
| 172 |
+
|
| 173 |
+
element_d = self.dag_ir.get_node_meta("D").element
|
| 174 |
+
|
| 175 |
+
DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void
|
| 176 |
+
|
| 177 |
+
CtaM = cta_tile[0]
|
| 178 |
+
CtaN = cta_tile[1]
|
| 179 |
+
WarpM = tmem_warps[0]
|
| 180 |
+
WarpN = tmem_warps[1]
|
| 181 |
+
MaxBits = max(element_c_size, DataTypeSize[element_d])
|
| 182 |
+
DpFull = 32
|
| 183 |
+
M = min(CtaM, DpFull * WarpM)
|
| 184 |
+
|
| 185 |
+
if DisableSource:
|
| 186 |
+
# Epilogues w/o residual load are less sensitive to smem allocation
|
| 187 |
+
# Target a fixed amount of compute per epilogue iteration
|
| 188 |
+
if MaxBits == 4:
|
| 189 |
+
# Make epilogue tile larger to reduce the epilogue iterations.
|
| 190 |
+
# 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
| 191 |
+
ComputeElts = 8192
|
| 192 |
+
Nperf = ComputeElts // M
|
| 193 |
+
else:
|
| 194 |
+
ComputeElts = 4096
|
| 195 |
+
Nperf = ComputeElts // M
|
| 196 |
+
else:
|
| 197 |
+
# Epilogues w/ residual load are more sensitive to smem allocation
|
| 198 |
+
# Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
| 199 |
+
if MaxBits == 32:
|
| 200 |
+
Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
|
| 201 |
+
elif MaxBits == 16:
|
| 202 |
+
Nperf = 32 if CtaN <= 128 else 64
|
| 203 |
+
else:
|
| 204 |
+
Nperf = 64
|
| 205 |
+
|
| 206 |
+
def is_m_major(layout):
|
| 207 |
+
return flatten(layout.stride[0]) == 1
|
| 208 |
+
|
| 209 |
+
if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
|
| 210 |
+
N_min_C = 8 * WarpN
|
| 211 |
+
elif element_c_size == 6:
|
| 212 |
+
N_min_C = 128 * WarpN
|
| 213 |
+
else:
|
| 214 |
+
N_min_C = (128 // element_c_size) * WarpN
|
| 215 |
+
|
| 216 |
+
if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
|
| 217 |
+
N_min_D = 8 * WarpN
|
| 218 |
+
elif DataTypeSize[element_d] == 6:
|
| 219 |
+
N_min_D = 128 * WarpN
|
| 220 |
+
else:
|
| 221 |
+
N_min_D = (128 // DataTypeSize[element_d]) * WarpN
|
| 222 |
+
|
| 223 |
+
N = min(CtaN, max(Nperf, N_min_C, N_min_D))
|
| 224 |
+
|
| 225 |
+
tile_m = M
|
| 226 |
+
tile_n_size = N // WarpN * WarpN
|
| 227 |
+
|
| 228 |
+
epilogue_tile_mn = (tile_m, tile_n_size)
|
| 229 |
+
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
|
| 230 |
+
|
| 231 |
+
stages_d = min(epi_tiles, 2)
|
| 232 |
+
reuse_smem_c = (element_c_size > 8)
|
| 233 |
+
|
| 234 |
+
if reuse_smem_c:
|
| 235 |
+
stages_c = max(min(epi_tiles, 4), stages_d + 1)
|
| 236 |
+
else:
|
| 237 |
+
stages_c = min(epi_tiles, 4)
|
| 238 |
+
|
| 239 |
+
# Record the epilogue tile
|
| 240 |
+
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
|
| 241 |
+
self.epilogue_tile_mn = epilogue_tile_mn
|
| 242 |
+
self.epi_tiles = epi_tiles
|
| 243 |
+
self.stages_c = stages_c
|
| 244 |
+
self.stages_d = stages_d
|
| 245 |
+
self.reuse_smem_c = reuse_smem_c
|
| 246 |
+
self.element_c = element_c
|
| 247 |
+
self.element_d = element_d
|
| 248 |
+
self.is_source_supported = not DisableSource
|
| 249 |
+
|
| 250 |
+
def sm100_epilogue_smem_size(self, tile_description):
|
| 251 |
+
"""
|
| 252 |
+
Compute the shared memory size of sm100 collective epilogue
|
| 253 |
+
"""
|
| 254 |
+
self.sm100_epilogue_tile(tile_description)
|
| 255 |
+
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
| 256 |
+
|
| 257 |
+
def __call__(self, tile_description):
|
| 258 |
+
return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
|
| 259 |
+
|
| 260 |
+
#
|
| 261 |
+
# Helper functions
|
| 262 |
+
#
|
| 263 |
+
|
| 264 |
+
@staticmethod
|
| 265 |
+
def get_visitor_size(members: list, ebo: bool):
|
| 266 |
+
"""
|
| 267 |
+
Get the size of struct in bytes
|
| 268 |
+
"""
|
| 269 |
+
offset = 0
|
| 270 |
+
max_alignment = 1
|
| 271 |
+
if len(members) > 0:
|
| 272 |
+
# Get alignment
|
| 273 |
+
for _, alignment in members:
|
| 274 |
+
max_alignment = max(max_alignment, alignment)
|
| 275 |
+
|
| 276 |
+
for type_size, _ in members:
|
| 277 |
+
if type_size != 0:
|
| 278 |
+
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
| 279 |
+
if type_size == 0 and not ebo:
|
| 280 |
+
offset += 1
|
| 281 |
+
else:
|
| 282 |
+
offset += type_size
|
| 283 |
+
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
| 284 |
+
return (offset, max_alignment)
|
| 285 |
+
else:
|
| 286 |
+
# Struct size is at least 1
|
| 287 |
+
return (1, 1)
|
| 288 |
+
|
| 289 |
+
def get_struct_size(self, members: list):
|
| 290 |
+
"""
|
| 291 |
+
Get the size of struct in bytes
|
| 292 |
+
"""
|
| 293 |
+
return self.get_visitor_size(members, False)
|
| 294 |
+
|
| 295 |
+
def get_evt_smem_type(self, node):
|
| 296 |
+
# Sort the input nodes by edge weight
|
| 297 |
+
input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
|
| 298 |
+
input_types.append(self.smem_types[node])
|
| 299 |
+
if len(input_types) > 1:
|
| 300 |
+
ebo = len(input_types) > 4
|
| 301 |
+
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
| 302 |
+
|
| 303 |
+
def get_dag_smem_type(self, node):
|
| 304 |
+
meta = self.dag_ir.get_node_meta(node)
|
| 305 |
+
subgraph = meta.subgraph
|
| 306 |
+
subgraph_nodes = subgraph.nodes_topological_order()
|
| 307 |
+
# Visit the unvisited nodes in subgraph
|
| 308 |
+
for n in subgraph_nodes:
|
| 309 |
+
m = subgraph.get_node_meta(n)
|
| 310 |
+
if m.disabled:
|
| 311 |
+
continue
|
| 312 |
+
else:
|
| 313 |
+
self.smem_types[n] = m.underlying_impl.get_smem_size(
|
| 314 |
+
self.cta_tile_mnk, self.epilogue_tile_mn,
|
| 315 |
+
self.stages_c, self.stages_d, self.epi_tiles)
|
| 316 |
+
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
|
| 317 |
+
if len(input_types) > 0:
|
| 318 |
+
ebo = len(input_types) > 4
|
| 319 |
+
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for passes
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
# Map from the CC of the kernel to the EVT implementation that the CC targets
|
| 38 |
+
cc_map = {
|
| 39 |
+
80: 80,
|
| 40 |
+
86: 80,
|
| 41 |
+
89: 80,
|
| 42 |
+
90: 90,
|
| 43 |
+
100: 100,
|
| 44 |
+
101: 100,
|
| 45 |
+
103: 100,
|
| 46 |
+
}
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 35 |
+
cuda = lazy_import("cuda.cuda")
|
| 36 |
+
import numpy as np
|
| 37 |
+
|
| 38 |
+
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
|
| 39 |
+
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class NumpyFrontend:
|
| 43 |
+
"""
|
| 44 |
+
Frontend node for numpy
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
@staticmethod
|
| 48 |
+
def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
|
| 49 |
+
"""Convert the input numpy tensor to CUDA device pointer
|
| 50 |
+
|
| 51 |
+
:param np_tensor: input numpy nd array
|
| 52 |
+
:param is_output: whether the tensor is output
|
| 53 |
+
|
| 54 |
+
:return: CUDA device pointer
|
| 55 |
+
"""
|
| 56 |
+
# copy the data to device
|
| 57 |
+
if is_output:
|
| 58 |
+
return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
|
| 59 |
+
else:
|
| 60 |
+
return todevice(np_tensor)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class TorchFrontend:
|
| 64 |
+
"""
|
| 65 |
+
Frontend node for torch
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
|
| 70 |
+
"""Convert the input torch tensor to CUDA device pointer
|
| 71 |
+
|
| 72 |
+
:param torch_tensor: input torch tensor
|
| 73 |
+
:param is_output: whether the tensor is output
|
| 74 |
+
|
| 75 |
+
:return: CUDA device pointer
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
# check the device of torch_tensor
|
| 79 |
+
if not torch_tensor.is_cuda:
|
| 80 |
+
torch_tensor = torch_tensor.to("cuda")
|
| 81 |
+
|
| 82 |
+
return cuda.CUdeviceptr(torch_tensor.data_ptr())
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class CupyFrontend:
|
| 86 |
+
"""
|
| 87 |
+
Frontend node for cupy
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def argument(cupy_ndarray: "cp.ndarray"):
|
| 92 |
+
return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class TensorFrontend:
|
| 96 |
+
"""
|
| 97 |
+
Universal Frontend for client-provide tensors
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
@staticmethod
|
| 101 |
+
def argument(tensor, is_output=False):
|
| 102 |
+
if is_numpy_tensor(tensor):
|
| 103 |
+
return NumpyFrontend.argument(tensor, is_output)
|
| 104 |
+
elif is_torch_tensor(tensor):
|
| 105 |
+
return TorchFrontend.argument(tensor)
|
| 106 |
+
elif is_cupy_tensor(tensor):
|
| 107 |
+
return CupyFrontend.argument(tensor)
|
| 108 |
+
else:
|
| 109 |
+
raise NotImplementedError("Unknown Tensor Type")
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py
ADDED
|
@@ -0,0 +1,2145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import copy
|
| 35 |
+
import ctypes
|
| 36 |
+
import enum
|
| 37 |
+
|
| 38 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 39 |
+
cuda = lazy_import("cuda.cuda")
|
| 40 |
+
cudart = lazy_import("cuda.cudart")
|
| 41 |
+
from cutlass_library import SubstituteTemplate
|
| 42 |
+
import numpy as np
|
| 43 |
+
|
| 44 |
+
from cutlass_library import (
|
| 45 |
+
ComplexTransformTag,
|
| 46 |
+
DataType,
|
| 47 |
+
DataTypeNames,
|
| 48 |
+
DataTypeSize,
|
| 49 |
+
DataTypeTag,
|
| 50 |
+
EpilogueScheduleSuffixes,
|
| 51 |
+
EpilogueScheduleTag,
|
| 52 |
+
EpilogueScheduleType,
|
| 53 |
+
GemmKind,
|
| 54 |
+
GemmKindNames,
|
| 55 |
+
GemmUniversalMode,
|
| 56 |
+
KernelScheduleSuffixes,
|
| 57 |
+
KernelScheduleTag,
|
| 58 |
+
KernelScheduleType,
|
| 59 |
+
LayoutTag,
|
| 60 |
+
LayoutType,
|
| 61 |
+
MathOperation,
|
| 62 |
+
MathOperationTag,
|
| 63 |
+
OpcodeClass,
|
| 64 |
+
OpcodeClassNames,
|
| 65 |
+
OpcodeClassTag,
|
| 66 |
+
OperationKind,
|
| 67 |
+
ShortComplexLayoutNames,
|
| 68 |
+
ShortDataTypeNames,
|
| 69 |
+
ShortLayoutTypeNames,
|
| 70 |
+
SwizzlingFunctor,
|
| 71 |
+
SwizzlingFunctorTag,
|
| 72 |
+
TileSchedulerSuffixes,
|
| 73 |
+
TileSchedulerTag,
|
| 74 |
+
TileSchedulerType,
|
| 75 |
+
get_complex_from_real
|
| 76 |
+
)
|
| 77 |
+
from cutlass_cppgen.backend.arguments import ArgumentBase
|
| 78 |
+
from cutlass_cppgen.backend.c_types import (
|
| 79 |
+
GemmCoord_,
|
| 80 |
+
GemmCoordBatched_,
|
| 81 |
+
GenericMainloopArguments3x_,
|
| 82 |
+
StrideBatched_,
|
| 83 |
+
dim3_,
|
| 84 |
+
get_gemm_arguments,
|
| 85 |
+
get_gemm_arguments_3x,
|
| 86 |
+
get_gemm_arguments_streamk,
|
| 87 |
+
get_gemm_grouped_arguments,
|
| 88 |
+
get_mainloop_arguments_3x,
|
| 89 |
+
get_tile_scheduler_arguments_3x,
|
| 90 |
+
)
|
| 91 |
+
from cutlass_cppgen.backend.library import (
|
| 92 |
+
ApiVersion,
|
| 93 |
+
EmissionType,
|
| 94 |
+
SchedulerMode,
|
| 95 |
+
SchedulerModeTag,
|
| 96 |
+
TensorDescription,
|
| 97 |
+
TileDescription,
|
| 98 |
+
api_version,
|
| 99 |
+
)
|
| 100 |
+
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
|
| 101 |
+
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
| 102 |
+
from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor
|
| 103 |
+
from cutlass_cppgen.backend.utils.device import device_sm_count
|
| 104 |
+
from cutlass_cppgen.shape import GemmCoord, MatrixCoord
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
################################################################################
|
| 108 |
+
#
|
| 109 |
+
# Data structure modeling a GEMM operation
|
| 110 |
+
#
|
| 111 |
+
################################################################################
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int:
|
| 115 |
+
"""
|
| 116 |
+
Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``.
|
| 117 |
+
|
| 118 |
+
:param layout: layout of the tensor
|
| 119 |
+
:type layout: cutlass_cppgen.shape.LayoutType
|
| 120 |
+
:param shape: shape of the tensor
|
| 121 |
+
:type shape: cutlass_cppgen.shape.MatrixCoord
|
| 122 |
+
|
| 123 |
+
:return: leading dimension of the tensor
|
| 124 |
+
:rtype: int
|
| 125 |
+
"""
|
| 126 |
+
if layout == LayoutType.RowMajor:
|
| 127 |
+
return shape.column
|
| 128 |
+
elif layout == LayoutType.ColumnMajor:
|
| 129 |
+
return shape.row
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def transpose_layout(layout: LayoutType) -> LayoutType:
|
| 133 |
+
if layout == LayoutType.ColumnMajor:
|
| 134 |
+
return LayoutType.RowMajor
|
| 135 |
+
elif layout == LayoutType.RowMajor:
|
| 136 |
+
return LayoutType.ColumnMajor
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(f"Unsupported Layout {layout}")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class GemmArguments2x(ArgumentBase):
|
| 142 |
+
"""
|
| 143 |
+
Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and
|
| 144 |
+
user-provide tensors into the kernel's argument
|
| 145 |
+
|
| 146 |
+
:param operation: the GEMM operation to take the argument
|
| 147 |
+
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 148 |
+
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 149 |
+
|
| 150 |
+
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 151 |
+
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 152 |
+
|
| 153 |
+
:param A: tensor A
|
| 154 |
+
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 155 |
+
|
| 156 |
+
:param B: tensor B
|
| 157 |
+
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 158 |
+
|
| 159 |
+
:param C: tensor C
|
| 160 |
+
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 161 |
+
|
| 162 |
+
:param D: tensor D
|
| 163 |
+
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 164 |
+
|
| 165 |
+
:param gemm_mode: GEMM mode
|
| 166 |
+
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
| 167 |
+
|
| 168 |
+
:param output_op: output operator, optional
|
| 169 |
+
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 170 |
+
|
| 171 |
+
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 172 |
+
:type stream: :class:`cuda.cuda.CUstream`
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 176 |
+
self.operation = operation
|
| 177 |
+
|
| 178 |
+
self.layout_A = operation.A.layout
|
| 179 |
+
self.layout_B = operation.B.layout
|
| 180 |
+
self.layout_C = operation.C.layout
|
| 181 |
+
|
| 182 |
+
self.element_A = operation.A.element
|
| 183 |
+
self.element_B = operation.B.element
|
| 184 |
+
self.element_C = operation.C.element
|
| 185 |
+
|
| 186 |
+
if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]:
|
| 187 |
+
raise Exception("Interleaved layout not currently supported")
|
| 188 |
+
|
| 189 |
+
if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]:
|
| 190 |
+
super().__init__(A, B, None, None, **kwargs)
|
| 191 |
+
else:
|
| 192 |
+
super().__init__(A, B, C, D, **kwargs)
|
| 193 |
+
|
| 194 |
+
if operation.switched:
|
| 195 |
+
self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
|
| 196 |
+
self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A
|
| 197 |
+
else:
|
| 198 |
+
self.problem_size = problem_size
|
| 199 |
+
# If the number of elements in C = problem_size.n, C is treated as the bias
|
| 200 |
+
if hasattr(self, "tensor_c_numel"):
|
| 201 |
+
if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1:
|
| 202 |
+
self.bias = True
|
| 203 |
+
|
| 204 |
+
self.lda = leading_dimension(self.layout_A, self.problem_size.mk)
|
| 205 |
+
self.ldb = leading_dimension(self.layout_B, self.problem_size.kn)
|
| 206 |
+
self.ldc = leading_dimension(self.layout_C, self.problem_size.mn)
|
| 207 |
+
self.ldd = self.ldc
|
| 208 |
+
|
| 209 |
+
if self.bias:
|
| 210 |
+
self.ldc = 0
|
| 211 |
+
|
| 212 |
+
if "output_op" in kwargs.keys() and gemm_mode != GemmUniversalMode.GemmSplitKParallel:
|
| 213 |
+
self.output_op = kwargs["output_op"]
|
| 214 |
+
else:
|
| 215 |
+
if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
|
| 216 |
+
dtype = int
|
| 217 |
+
else:
|
| 218 |
+
dtype = float
|
| 219 |
+
self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0))
|
| 220 |
+
|
| 221 |
+
self.gemm_mode = gemm_mode
|
| 222 |
+
if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
|
| 223 |
+
if "split_k_slices" in kwargs.keys():
|
| 224 |
+
self.batch_count = kwargs["split_k_slices"]
|
| 225 |
+
else:
|
| 226 |
+
self.batch_count = 1
|
| 227 |
+
self.split_k_slices = self.batch_count
|
| 228 |
+
|
| 229 |
+
if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]:
|
| 230 |
+
if "batch" in kwargs.keys():
|
| 231 |
+
self.batch_count = kwargs["batch"]
|
| 232 |
+
else:
|
| 233 |
+
self.batch_count = 1
|
| 234 |
+
|
| 235 |
+
if "batch_strides" in kwargs:
|
| 236 |
+
self.batched_stride_A = kwargs["batch_strides"]["A"]
|
| 237 |
+
self.batched_stride_B = kwargs["batch_strides"]["B"]
|
| 238 |
+
self.batched_stride_C = kwargs["batch_strides"]["C"]
|
| 239 |
+
self.batched_stride_D = kwargs["batch_strides"]["D"]
|
| 240 |
+
else:
|
| 241 |
+
self.batched_stride_A = self.problem_size.m * self.problem_size.k
|
| 242 |
+
self.batched_stride_B = self.problem_size.n * self.problem_size.k
|
| 243 |
+
self.batched_stride_C = self.problem_size.m * self.problem_size.n
|
| 244 |
+
self.batched_stride_D = self.problem_size.m * self.problem_size.n
|
| 245 |
+
|
| 246 |
+
if self.bias:
|
| 247 |
+
self.batched_stride_C = self.problem_size.n
|
| 248 |
+
|
| 249 |
+
if gemm_mode == GemmUniversalMode.Array:
|
| 250 |
+
self.ptr_A_array = []
|
| 251 |
+
self.ptr_B_array = []
|
| 252 |
+
self.ptr_C_array = []
|
| 253 |
+
self.ptr_D_array = []
|
| 254 |
+
|
| 255 |
+
ptr_A_addr = int(self.ptr_A)
|
| 256 |
+
ptr_B_addr = int(self.ptr_B)
|
| 257 |
+
ptr_C_addr = int(self.ptr_C)
|
| 258 |
+
ptr_D_addr = int(self.ptr_D)
|
| 259 |
+
|
| 260 |
+
stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
|
| 261 |
+
stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
|
| 262 |
+
stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
|
| 263 |
+
stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
|
| 264 |
+
for _ in range(self.batch_count):
|
| 265 |
+
self.ptr_A_array.append(ptr_A_addr)
|
| 266 |
+
self.ptr_B_array.append(ptr_B_addr)
|
| 267 |
+
self.ptr_C_array.append(ptr_C_addr)
|
| 268 |
+
self.ptr_D_array.append(ptr_D_addr)
|
| 269 |
+
|
| 270 |
+
ptr_A_addr += stride_A
|
| 271 |
+
ptr_B_addr += stride_B
|
| 272 |
+
ptr_C_addr += stride_C
|
| 273 |
+
ptr_D_addr += stride_D
|
| 274 |
+
|
| 275 |
+
self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
|
| 276 |
+
self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
|
| 277 |
+
self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
|
| 278 |
+
self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)
|
| 279 |
+
|
| 280 |
+
if isinstance(self.operation, GemmOperationUniversal):
|
| 281 |
+
self.initialize()
|
| 282 |
+
|
| 283 |
+
def get_arguments(self):
|
| 284 |
+
problem_size_ = self.problem_size.ctype
|
| 285 |
+
grid_tiled_shape_ = GemmCoord(
|
| 286 |
+
self.grid_tiled_shape.x,
|
| 287 |
+
self.grid_tiled_shape.y,
|
| 288 |
+
self.grid_tiled_shape.z ).ctype
|
| 289 |
+
|
| 290 |
+
if self.gemm_mode == GemmUniversalMode.Array:
|
| 291 |
+
arguments = self.operation.argument_type(
|
| 292 |
+
# Arguments from UniversalArgumentsBase
|
| 293 |
+
self.gemm_mode,
|
| 294 |
+
problem_size_,
|
| 295 |
+
self.batch_count,
|
| 296 |
+
0,
|
| 297 |
+
# Remaining arguments
|
| 298 |
+
self.output_op,
|
| 299 |
+
int(self.ptr_A_array_buffer.ptr),
|
| 300 |
+
int(self.ptr_B_array_buffer.ptr),
|
| 301 |
+
int(self.ptr_C_array_buffer.ptr),
|
| 302 |
+
int(self.ptr_D_array_buffer.ptr),
|
| 303 |
+
0, 0, 0,
|
| 304 |
+
self.lda, self.ldb, self.ldc, self.ldd,
|
| 305 |
+
self.lda, self.ldb, self.ldc, self.ldd,
|
| 306 |
+
0, 0, 0
|
| 307 |
+
)
|
| 308 |
+
else:
|
| 309 |
+
arguments = self.operation.argument_type(
|
| 310 |
+
# Arguments from UniversalArgumentsBase
|
| 311 |
+
self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
|
| 312 |
+
# Remaining arguments
|
| 313 |
+
self.output_op,
|
| 314 |
+
int(self.ptr_A),
|
| 315 |
+
int(self.ptr_B),
|
| 316 |
+
int(self.ptr_C),
|
| 317 |
+
int(self.ptr_D),
|
| 318 |
+
self.batched_stride_A,
|
| 319 |
+
self.batched_stride_B,
|
| 320 |
+
self.batched_stride_C,
|
| 321 |
+
self.lda, self.ldb, self.ldc, self.ldd,
|
| 322 |
+
self.lda, self.ldb, self.ldc, self.ldd,
|
| 323 |
+
0, 0, 0
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size
|
| 327 |
+
|
| 328 |
+
def initialize(self):
|
| 329 |
+
launch_config = self.operation.rt_module.plan(self)
|
| 330 |
+
|
| 331 |
+
# Get the host and device workspace
|
| 332 |
+
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
| 333 |
+
|
| 334 |
+
if device_workspace_size > 0:
|
| 335 |
+
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 336 |
+
workspace_ptr = self.workspace_buffer.ptr
|
| 337 |
+
err, = cuda.cuMemsetD32(
|
| 338 |
+
workspace_ptr, 0, device_workspace_size // 4)
|
| 339 |
+
else:
|
| 340 |
+
workspace_ptr = None
|
| 341 |
+
|
| 342 |
+
device_workspace = 0
|
| 343 |
+
if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 344 |
+
# In GEMM splik-K parallel, the D pointer is redirected to the workspace
|
| 345 |
+
self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
|
| 346 |
+
elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
|
| 347 |
+
device_workspace = workspace_ptr
|
| 348 |
+
|
| 349 |
+
self.get_arguments()
|
| 350 |
+
|
| 351 |
+
arguments, grid_tiled_shape, gemm_k_size = self.arguments
|
| 352 |
+
res_arg = self.operation.rt_module.get_args(
|
| 353 |
+
ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
|
| 354 |
+
host_workspace = bytearray(res_arg.contents)
|
| 355 |
+
|
| 356 |
+
device_workspace = None
|
| 357 |
+
|
| 358 |
+
self.host_workspace = host_workspace
|
| 359 |
+
self.device_workspace = device_workspace
|
| 360 |
+
self.launch_config = launch_config
|
| 361 |
+
|
| 362 |
+
def sync(self, stream_sync=True):
|
| 363 |
+
super().sync(stream_sync)
|
| 364 |
+
if hasattr(self.output_op, "sync"):
|
| 365 |
+
self.output_op.sync()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class GemmArguments2xStreamK(GemmArguments2x):
|
| 369 |
+
"""
|
| 370 |
+
Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and
|
| 371 |
+
user-provide tensors into the kernel's argument
|
| 372 |
+
|
| 373 |
+
:param operation: the GEMM operation to take the argument
|
| 374 |
+
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 375 |
+
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 376 |
+
|
| 377 |
+
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 378 |
+
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 379 |
+
|
| 380 |
+
:param A: tensor A
|
| 381 |
+
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 382 |
+
|
| 383 |
+
:param B: tensor B
|
| 384 |
+
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 385 |
+
|
| 386 |
+
:param C: tensor C
|
| 387 |
+
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 388 |
+
|
| 389 |
+
:param D: tensor D
|
| 390 |
+
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 391 |
+
|
| 392 |
+
:param gemm_mode: GEMM mode
|
| 393 |
+
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
| 394 |
+
|
| 395 |
+
:param output_op: output operator, optional
|
| 396 |
+
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 400 |
+
if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
|
| 401 |
+
raise Exception(f"Unsupported GEMM mode {gemm_mode}.")
|
| 402 |
+
|
| 403 |
+
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
| 404 |
+
|
| 405 |
+
def get_arguments(self):
|
| 406 |
+
batch_stride_A = self.problem_size.m * self.problem_size.k
|
| 407 |
+
batch_stride_B = self.problem_size.k * self.problem_size.n
|
| 408 |
+
batch_stride_C = self.problem_size.m * self.problem_size.n
|
| 409 |
+
batch_stride_D = self.problem_size.m * self.problem_size.n
|
| 410 |
+
|
| 411 |
+
arguments = self.operation.argument_type(
|
| 412 |
+
self.gemm_mode,
|
| 413 |
+
GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k),
|
| 414 |
+
self.batch_count,
|
| 415 |
+
self.output_op,
|
| 416 |
+
int(self.ptr_A),
|
| 417 |
+
int(self.ptr_B),
|
| 418 |
+
int(self.ptr_C),
|
| 419 |
+
int(self.ptr_D),
|
| 420 |
+
batch_stride_A,
|
| 421 |
+
batch_stride_B,
|
| 422 |
+
batch_stride_C,
|
| 423 |
+
batch_stride_D,
|
| 424 |
+
self.lda, self.ldb, self.ldc, self.ldd, # strides
|
| 425 |
+
self.lda, self.ldb, self.ldc, self.ldd,
|
| 426 |
+
-1, # avail_sms
|
| 427 |
+
)
|
| 428 |
+
return arguments
|
| 429 |
+
|
| 430 |
+
def initialize(self):
|
| 431 |
+
# Get the host and device workspace
|
| 432 |
+
device_workspace_size = self.operation.rt_module.get_device_workspace_size(
|
| 433 |
+
self,
|
| 434 |
+
device_sm_count(),
|
| 435 |
+
self.operation.rt_module.occupancy
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
if device_workspace_size > 0:
|
| 439 |
+
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 440 |
+
workspace_ptr = self.workspace_buffer.ptr
|
| 441 |
+
err, = cuda.cuMemsetD32(
|
| 442 |
+
workspace_ptr, 0, device_workspace_size // 4)
|
| 443 |
+
else:
|
| 444 |
+
workspace_ptr = None
|
| 445 |
+
|
| 446 |
+
device_workspace = 0
|
| 447 |
+
if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 448 |
+
# In GEMM splik-K parallel, the D pointer is redirected to the workspace
|
| 449 |
+
self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
|
| 450 |
+
elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
|
| 451 |
+
device_workspace = workspace_ptr
|
| 452 |
+
|
| 453 |
+
arguments = self.get_arguments()
|
| 454 |
+
|
| 455 |
+
res_arg = self.operation.rt_module.get_args(
|
| 456 |
+
ctypes.byref(arguments),
|
| 457 |
+
ctypes.c_void_p(int(device_workspace)),
|
| 458 |
+
device_sm_count(),
|
| 459 |
+
self.operation.rt_module.occupancy
|
| 460 |
+
)
|
| 461 |
+
host_workspace = bytearray(res_arg.contents)
|
| 462 |
+
|
| 463 |
+
grid = self.operation.rt_module.get_grid_shape(
|
| 464 |
+
ctypes.byref(arguments),
|
| 465 |
+
device_sm_count(),
|
| 466 |
+
self.operation.rt_module.occupancy
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
device_workspace = None
|
| 470 |
+
|
| 471 |
+
self.host_workspace = host_workspace
|
| 472 |
+
self.device_workspace = device_workspace
|
| 473 |
+
self.launch_config = LaunchConfiguration(
|
| 474 |
+
[grid.m, grid.n, grid.k],
|
| 475 |
+
[self.operation.rt_module.threads, 1, 1],
|
| 476 |
+
self.operation.rt_module.shared_memory_capacity
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
class GemmArguments3x(GemmArguments2x):
|
| 481 |
+
"""
|
| 482 |
+
Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and
|
| 483 |
+
user-provide tensors into the kernel's argument
|
| 484 |
+
|
| 485 |
+
:param operation: the GEMM operation to take the argument
|
| 486 |
+
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 487 |
+
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 488 |
+
|
| 489 |
+
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 490 |
+
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 491 |
+
|
| 492 |
+
:param A: tensor A
|
| 493 |
+
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 494 |
+
|
| 495 |
+
:param B: tensor B
|
| 496 |
+
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 497 |
+
|
| 498 |
+
:param C: tensor C
|
| 499 |
+
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 500 |
+
|
| 501 |
+
:param D: tensor D
|
| 502 |
+
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 503 |
+
|
| 504 |
+
:param gemm_mode: GEMM mode
|
| 505 |
+
:type gemm_mode: GemmUniversalMode
|
| 506 |
+
|
| 507 |
+
:param output_op: output operator, optional
|
| 508 |
+
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 512 |
+
if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
|
| 513 |
+
raise Exception(f"Unsupported GEMM mode {gemm_mode}.")
|
| 514 |
+
|
| 515 |
+
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
| 516 |
+
|
| 517 |
+
def get_arguments(self):
|
| 518 |
+
mainloop_args = get_mainloop_arguments_3x(
|
| 519 |
+
self.operation.tile_description.kernel_schedule,
|
| 520 |
+
self.operation.A.element,
|
| 521 |
+
self.operation.B.element,
|
| 522 |
+
self.operation.A.alignment,
|
| 523 |
+
self.operation.B.alignment
|
| 524 |
+
)
|
| 525 |
+
scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
|
| 526 |
+
uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
|
| 527 |
+
argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
|
| 528 |
+
mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
|
| 529 |
+
|
| 530 |
+
problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)
|
| 531 |
+
|
| 532 |
+
if self.batch_count > 1:
|
| 533 |
+
bsA = self.batched_stride_A
|
| 534 |
+
bsB = self.batched_stride_B
|
| 535 |
+
bsC = self.batched_stride_C
|
| 536 |
+
bsD = self.batched_stride_D
|
| 537 |
+
else:
|
| 538 |
+
bsA = 0
|
| 539 |
+
bsB = 0
|
| 540 |
+
bsC = 0
|
| 541 |
+
bsD = 0
|
| 542 |
+
stride_A = StrideBatched_(self.lda, bsA)
|
| 543 |
+
stride_B = StrideBatched_(self.ldb, bsB)
|
| 544 |
+
stride_C = StrideBatched_(self.ldc, bsC)
|
| 545 |
+
stride_D = StrideBatched_(self.ldd, bsD)
|
| 546 |
+
|
| 547 |
+
# Superset of potential mainloop arguments
|
| 548 |
+
generic_args = GenericMainloopArguments3x_(
|
| 549 |
+
int(self.ptr_A),
|
| 550 |
+
stride_A,
|
| 551 |
+
int(self.ptr_B),
|
| 552 |
+
stride_B,
|
| 553 |
+
4 # mma_promotion_interval
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
# Set of mainloop arguments needed for this kernel
|
| 557 |
+
mainloop = mainloop_args.from_generic_mainloop_args(generic_args)
|
| 558 |
+
|
| 559 |
+
if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
|
| 560 |
+
self.output_op = self.output_op.to_evt_params()
|
| 561 |
+
|
| 562 |
+
epilogue = epilogue_args(
|
| 563 |
+
self.output_op,
|
| 564 |
+
int(self.ptr_C),
|
| 565 |
+
stride_C,
|
| 566 |
+
int(self.ptr_D),
|
| 567 |
+
stride_D,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Set hardware info
|
| 571 |
+
hw_info_ = hw_info(
|
| 572 |
+
0, device_sm_count(), 0,
|
| 573 |
+
dim3_(0,0,0),
|
| 574 |
+
dim3_(0,0,0),
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
self.arguments = argument_type(
|
| 578 |
+
int(self.gemm_mode),
|
| 579 |
+
problem_size_,
|
| 580 |
+
mainloop,
|
| 581 |
+
epilogue,
|
| 582 |
+
hw_info_,
|
| 583 |
+
scheduler_args
|
| 584 |
+
)
|
| 585 |
+
return self.arguments
|
| 586 |
+
|
| 587 |
+
def initialize(self):
|
| 588 |
+
# Get the host and evice workspace
|
| 589 |
+
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
| 590 |
+
|
| 591 |
+
if device_workspace_size > 0:
|
| 592 |
+
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 593 |
+
workspace_ptr = self.workspace_buffer.ptr
|
| 594 |
+
err, = cuda.cuMemsetD32(
|
| 595 |
+
workspace_ptr, 0, device_workspace_size // 4)
|
| 596 |
+
else:
|
| 597 |
+
workspace_ptr = None
|
| 598 |
+
|
| 599 |
+
device_workspace = 0
|
| 600 |
+
if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 601 |
+
# In GEMM splik-K parallel, the D pointer is redirected to the workspace
|
| 602 |
+
self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
|
| 603 |
+
elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
|
| 604 |
+
device_workspace = workspace_ptr
|
| 605 |
+
|
| 606 |
+
self.get_arguments()
|
| 607 |
+
res_arg = self.operation.rt_module.get_args(
|
| 608 |
+
ctypes.byref(self.arguments),
|
| 609 |
+
ctypes.c_void_p(int(device_workspace)),
|
| 610 |
+
)
|
| 611 |
+
host_workspace = bytearray(res_arg.contents)
|
| 612 |
+
|
| 613 |
+
grid = self.operation.rt_module.get_grid_shape(
|
| 614 |
+
ctypes.byref(self.arguments),
|
| 615 |
+
ctypes.c_void_p(int(device_workspace)),
|
| 616 |
+
)
|
| 617 |
+
block = self.operation.rt_module.get_block_shape()
|
| 618 |
+
|
| 619 |
+
device_workspace = None
|
| 620 |
+
|
| 621 |
+
self.host_workspace = host_workspace
|
| 622 |
+
self.device_workspace = device_workspace
|
| 623 |
+
self.launch_config = LaunchConfiguration(
|
| 624 |
+
[grid.x, grid.y, grid.z],
|
| 625 |
+
[block.x, block.y, block.z],
|
| 626 |
+
self.operation.rt_module.shared_memory_capacity,
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 631 |
+
"""
|
| 632 |
+
Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments
|
| 633 |
+
or 3x arguments depending on the `arch` field specified in `operation`.
|
| 634 |
+
|
| 635 |
+
:param operation: the GEMM operation to take the argument
|
| 636 |
+
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 637 |
+
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 638 |
+
|
| 639 |
+
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 640 |
+
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 641 |
+
|
| 642 |
+
:param A: tensor A
|
| 643 |
+
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 644 |
+
|
| 645 |
+
:param B: tensor B
|
| 646 |
+
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 647 |
+
|
| 648 |
+
:param C: tensor C
|
| 649 |
+
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 650 |
+
|
| 651 |
+
:param D: tensor D
|
| 652 |
+
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 653 |
+
|
| 654 |
+
:param gemm_mode: GEMM mode
|
| 655 |
+
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
| 656 |
+
|
| 657 |
+
:param output_op: output operator, optional
|
| 658 |
+
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 659 |
+
"""
|
| 660 |
+
if operation.swizzling_functor == SwizzlingFunctor.StreamK:
|
| 661 |
+
if operation.api == ApiVersion.v3x:
|
| 662 |
+
raise Exception("Stream K is currently only supported in CUTLASS 2.x")
|
| 663 |
+
ArgClass = GemmArguments2xStreamK
|
| 664 |
+
else:
|
| 665 |
+
ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x
|
| 666 |
+
return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
class GemmGroupedArguments:
|
| 670 |
+
"""
|
| 671 |
+
Argument wrapper for GEMM Grouped. It encodes problem information and
|
| 672 |
+
user-provide tensors into the kernel's argument
|
| 673 |
+
|
| 674 |
+
:param operation: the GEMM Grouped operation to take the argument
|
| 675 |
+
:type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 676 |
+
|
| 677 |
+
:param problem_size: list of GEMM problem size gemm(M, N, K)
|
| 678 |
+
:type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`]
|
| 679 |
+
|
| 680 |
+
:param A: list of tensor A
|
| 681 |
+
:type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 682 |
+
|
| 683 |
+
:param B: list of tensor B
|
| 684 |
+
:type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 685 |
+
|
| 686 |
+
:param C: list of tensor C
|
| 687 |
+
:type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 688 |
+
|
| 689 |
+
:param D: list of tensor D
|
| 690 |
+
:type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 691 |
+
|
| 692 |
+
:param output_op: output operator, optional
|
| 693 |
+
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 694 |
+
|
| 695 |
+
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 696 |
+
:type stream: :class:`cuda.cuda.CUstream`
|
| 697 |
+
"""
|
| 698 |
+
|
| 699 |
+
def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
|
| 700 |
+
# Get number of problems in the group
|
| 701 |
+
self.problem_count = len(problem_sizes)
|
| 702 |
+
|
| 703 |
+
# Check the input arguments
|
| 704 |
+
assert len(A) == self.problem_count
|
| 705 |
+
assert len(B) == self.problem_count
|
| 706 |
+
assert len(C) == self.problem_count
|
| 707 |
+
assert len(D) == self.problem_count
|
| 708 |
+
|
| 709 |
+
problem_size_host = []
|
| 710 |
+
self.ptr_A_host = []
|
| 711 |
+
self.ptr_B_host = []
|
| 712 |
+
self.ptr_C_host = []
|
| 713 |
+
self.ptr_D_host = []
|
| 714 |
+
|
| 715 |
+
lda_host = []
|
| 716 |
+
ldb_host = []
|
| 717 |
+
ldc_host = []
|
| 718 |
+
ldd_host = []
|
| 719 |
+
|
| 720 |
+
self.partitions = 1
|
| 721 |
+
|
| 722 |
+
self.operation = operation
|
| 723 |
+
|
| 724 |
+
# Get the threadblock
|
| 725 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 726 |
+
self.threadblock_shape = GemmCoord(
|
| 727 |
+
threadblock_shape[0],
|
| 728 |
+
threadblock_shape[1],
|
| 729 |
+
threadblock_shape[2],
|
| 730 |
+
)
|
| 731 |
+
self.threadblock_swizzle = operation.swizzling_functor
|
| 732 |
+
|
| 733 |
+
self.total_tiles = 0
|
| 734 |
+
|
| 735 |
+
self.gemm_arguments = []
|
| 736 |
+
|
| 737 |
+
self.stream = kwargs.get("stream", cuda.CUstream(0))
|
| 738 |
+
|
| 739 |
+
# Process the input arguments
|
| 740 |
+
for idx, problem_size in enumerate(problem_sizes):
|
| 741 |
+
M, N, K = problem_size.m, problem_size.n, problem_size.k
|
| 742 |
+
temp_argument = GemmArguments2x(
|
| 743 |
+
operation=operation,
|
| 744 |
+
problem_size=GemmCoord(M, N, K),
|
| 745 |
+
A=A[idx], B=B[idx], C=C[idx], D=D[idx])
|
| 746 |
+
self.gemm_arguments.append(temp_argument)
|
| 747 |
+
|
| 748 |
+
problem_size_host.append(
|
| 749 |
+
[temp_argument.problem_size.m,
|
| 750 |
+
temp_argument.problem_size.n,
|
| 751 |
+
temp_argument.problem_size.k]
|
| 752 |
+
)
|
| 753 |
+
|
| 754 |
+
self.ptr_A_host.append(int(temp_argument.ptr_A))
|
| 755 |
+
lda_host.append(temp_argument.lda)
|
| 756 |
+
|
| 757 |
+
self.ptr_B_host.append(int(temp_argument.ptr_B))
|
| 758 |
+
ldb_host.append(temp_argument.ldb)
|
| 759 |
+
|
| 760 |
+
self.ptr_C_host.append(int(temp_argument.ptr_C))
|
| 761 |
+
ldc_host.append(temp_argument.ldc)
|
| 762 |
+
|
| 763 |
+
self.ptr_D_host.append(int(temp_argument.ptr_D))
|
| 764 |
+
ldd_host.append(temp_argument.ldd)
|
| 765 |
+
|
| 766 |
+
# Get number of tiles
|
| 767 |
+
grid = self.operation.rt_module.get_grid_shape(
|
| 768 |
+
self.operation.rt_module.get_tiled_shape(
|
| 769 |
+
temp_argument.problem_size.ctype,
|
| 770 |
+
self.threadblock_shape.ctype,
|
| 771 |
+
temp_argument.batch_count
|
| 772 |
+
)
|
| 773 |
+
)
|
| 774 |
+
self.total_tiles += grid.x * grid.y * grid.z
|
| 775 |
+
|
| 776 |
+
self.problem_size_buffer = todevice(problem_size_host, np.int32)
|
| 777 |
+
self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64)
|
| 778 |
+
self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64)
|
| 779 |
+
self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64)
|
| 780 |
+
self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64)
|
| 781 |
+
|
| 782 |
+
self.lda_buffer = todevice(lda_host, np.int64)
|
| 783 |
+
self.ldb_buffer = todevice(ldb_host, np.int64)
|
| 784 |
+
self.ldc_buffer = todevice(ldc_host, np.int64)
|
| 785 |
+
self.ldd_buffer = todevice(ldd_host, np.int64)
|
| 786 |
+
|
| 787 |
+
if "output_op" in kwargs.keys():
|
| 788 |
+
self.alpha = kwargs["output_op"].alpha
|
| 789 |
+
self.beta = kwargs["output_op"].beta
|
| 790 |
+
else:
|
| 791 |
+
self.alpha = 1.0
|
| 792 |
+
self.beta = 0.0
|
| 793 |
+
|
| 794 |
+
if "output_op" in kwargs.keys():
|
| 795 |
+
self.output_op = kwargs["output_op"]
|
| 796 |
+
else:
|
| 797 |
+
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
| 798 |
+
|
| 799 |
+
# Get host problem size
|
| 800 |
+
self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
|
| 801 |
+
|
| 802 |
+
self.arguments = self.get_arguments()
|
| 803 |
+
|
| 804 |
+
self.initialize()
|
| 805 |
+
|
| 806 |
+
def get_arguments(self):
|
| 807 |
+
return self.operation.argument_type(
|
| 808 |
+
self.problem_size_buffer.ptr,
|
| 809 |
+
self.problem_count,
|
| 810 |
+
self.total_tiles,
|
| 811 |
+
self.output_op,
|
| 812 |
+
self.ptr_A_buffer.ptr,
|
| 813 |
+
self.ptr_B_buffer.ptr,
|
| 814 |
+
self.ptr_C_buffer.ptr,
|
| 815 |
+
self.ptr_D_buffer.ptr,
|
| 816 |
+
self.lda_buffer.ptr,
|
| 817 |
+
self.ldb_buffer.ptr,
|
| 818 |
+
self.ldc_buffer.ptr,
|
| 819 |
+
self.ldd_buffer.ptr,
|
| 820 |
+
ctypes.c_void_p(int(self.host_problem_size_ptr)),
|
| 821 |
+
)
|
| 822 |
+
|
| 823 |
+
def initialize(self):
|
| 824 |
+
# Get launch configuration
|
| 825 |
+
launch_config = self.operation.rt_module.plan(self)
|
| 826 |
+
|
| 827 |
+
# Get the host and evice workspace
|
| 828 |
+
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
| 829 |
+
|
| 830 |
+
if device_workspace_size > 0:
|
| 831 |
+
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 832 |
+
workspace_ptr = self.workspace_buffer.ptr
|
| 833 |
+
err, = cuda.cuMemsetD32(
|
| 834 |
+
workspace_ptr, 0, device_workspace_size // 4)
|
| 835 |
+
else:
|
| 836 |
+
workspace_ptr = None
|
| 837 |
+
|
| 838 |
+
if self.operation.precompute_mode == SchedulerMode.Host:
|
| 839 |
+
device_workspace_ptr = self.operation.rt_module.host_precompute(
|
| 840 |
+
self, self.operation.rt_module.get_workspace_size(self),)
|
| 841 |
+
else:
|
| 842 |
+
device_workspace_ptr = 0
|
| 843 |
+
|
| 844 |
+
result = self.operation.rt_module.get_args(
|
| 845 |
+
ctypes.byref(self.arguments),
|
| 846 |
+
self.total_tiles,
|
| 847 |
+
ctypes.c_void_p(int(device_workspace_ptr)),
|
| 848 |
+
)
|
| 849 |
+
host_workspace = bytearray(result.contents)
|
| 850 |
+
|
| 851 |
+
device_workspace = None
|
| 852 |
+
|
| 853 |
+
self.host_workspace = host_workspace
|
| 854 |
+
self.device_workspace = device_workspace
|
| 855 |
+
self.launch_config = launch_config
|
| 856 |
+
|
| 857 |
+
def sync(self):
|
| 858 |
+
err, = cudart.cudaDeviceSynchronize()
|
| 859 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 860 |
+
raise RuntimeError("CUDA Error %s" % str(err))
|
| 861 |
+
for arg in self.gemm_arguments:
|
| 862 |
+
arg.sync(stream_sync=False)
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
################################################################################
|
| 866 |
+
# Base class for GEMM runtime module
|
| 867 |
+
################################################################################
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
class GemmRTbase(ExecutableOperation):
|
| 871 |
+
"""
|
| 872 |
+
GemmRT manages the CUTLASS runtime components
|
| 873 |
+
"""
|
| 874 |
+
|
| 875 |
+
KernelTemplate = r"""
|
| 876 |
+
extern "C"
|
| 877 |
+
__global__ void
|
| 878 |
+
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
| 879 |
+
|
| 880 |
+
// Dynamic shared memory base pointer
|
| 881 |
+
extern __shared__ int SharedStorageBase[];
|
| 882 |
+
|
| 883 |
+
// Declare pointer to dynamic shared memory.
|
| 884 |
+
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
| 885 |
+
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
| 886 |
+
|
| 887 |
+
${operation_name}${operation_suffix}::invoke(params, *shared_storage);
|
| 888 |
+
}
|
| 889 |
+
"""
|
| 890 |
+
|
| 891 |
+
def __init__(self, operation: "GemmOperation"):
|
| 892 |
+
super().__init__(operation)
|
| 893 |
+
|
| 894 |
+
self.operation = operation
|
| 895 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 896 |
+
self.threadblock_shape = GemmCoord(
|
| 897 |
+
threadblock_shape[0], threadblock_shape[1], threadblock_shape[2])
|
| 898 |
+
self.threadblock_swizzle = operation.swizzling_functor
|
| 899 |
+
|
| 900 |
+
# Threads per threadblock
|
| 901 |
+
self.threads = operation.tile_description.num_threads
|
| 902 |
+
|
| 903 |
+
def emit(self):
|
| 904 |
+
return self.emitter.emit(self.operation)
|
| 905 |
+
|
| 906 |
+
def can_implement(self, configuration, arguments):
|
| 907 |
+
raise NotImplementedError()
|
| 908 |
+
|
| 909 |
+
def get_host_workspace_size(self, arguments):
|
| 910 |
+
raise NotImplementedError()
|
| 911 |
+
|
| 912 |
+
def get_device_workspace_size(self, arguments):
|
| 913 |
+
return 0
|
| 914 |
+
|
| 915 |
+
def initialize(self):
|
| 916 |
+
err, = cuda.cuFuncSetAttribute(
|
| 917 |
+
self.kernel,
|
| 918 |
+
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
| 919 |
+
value=self.shared_memory_capacity)
|
| 920 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 921 |
+
raise RuntimeError(
|
| 922 |
+
f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}"
|
| 923 |
+
)
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
################################################################################
|
| 927 |
+
# Runtime module for GEMM Universal
|
| 928 |
+
################################################################################
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
class GemmRTUniversal(GemmRTbase):
|
| 932 |
+
"""
|
| 933 |
+
GemmRTUniversal manages the CUTLASS runtime components
|
| 934 |
+
"""
|
| 935 |
+
|
| 936 |
+
HostTemplate = r"""
|
| 937 |
+
extern "C" {
|
| 938 |
+
// Get the size of params in bytes
|
| 939 |
+
int ${operation_name}_get_param_size(){
|
| 940 |
+
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 941 |
+
}
|
| 942 |
+
|
| 943 |
+
// Get the size of dynamic shared memory in bytes
|
| 944 |
+
int ${operation_name}_shared_memory_size() {
|
| 945 |
+
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 946 |
+
}
|
| 947 |
+
|
| 948 |
+
// Get the params as byte array
|
| 949 |
+
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
|
| 950 |
+
${operation_name}_base::Params* params;
|
| 951 |
+
params = new ${operation_name}_base::Params(*argument,
|
| 952 |
+
-1, // SM count. Only used for stream-K
|
| 953 |
+
-1 // Occupancy. Only used for stream-K
|
| 954 |
+
);
|
| 955 |
+
|
| 956 |
+
// Semaphore holds the pointer to the workspace in the Params struct
|
| 957 |
+
params->semaphore = workspace;
|
| 958 |
+
|
| 959 |
+
char *bytes = ((char*)(params));
|
| 960 |
+
char *output = new char[sizeof(${operation_name}_base::Params)];
|
| 961 |
+
for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
|
| 962 |
+
output[i] = bytes[i];
|
| 963 |
+
|
| 964 |
+
return output;
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
|
| 968 |
+
cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
|
| 969 |
+
return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
|
| 970 |
+
problem_size, tile_size, split_k_slices);
|
| 971 |
+
}
|
| 972 |
+
|
| 973 |
+
dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
|
| 974 |
+
return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
|
| 975 |
+
}
|
| 976 |
+
}
|
| 977 |
+
"""
|
| 978 |
+
|
| 979 |
+
def __init__(self, operation):
|
| 980 |
+
super(GemmRTUniversal, self).__init__(operation)
|
| 981 |
+
self.extra_funcs = {
|
| 982 |
+
"get_tiled_shape": GemmCoord_,
|
| 983 |
+
"get_grid_shape": dim3_,
|
| 984 |
+
}
|
| 985 |
+
self.emitter = EmitGemmUniversalInstance(
|
| 986 |
+
"_type", operation.direct_store)
|
| 987 |
+
|
| 988 |
+
self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
|
| 989 |
+
self.argtype = [
|
| 990 |
+
ctypes.POINTER(self.argument_type),
|
| 991 |
+
ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
|
| 992 |
+
]
|
| 993 |
+
|
| 994 |
+
def plan(self, arguments):
|
| 995 |
+
grid = self.get_tiled_shape(
|
| 996 |
+
arguments.problem_size.ctype,
|
| 997 |
+
self.threadblock_shape.ctype,
|
| 998 |
+
arguments.batch_count
|
| 999 |
+
)
|
| 1000 |
+
|
| 1001 |
+
gemm_k_size = arguments.problem_size.k
|
| 1002 |
+
if arguments.gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
|
| 1003 |
+
alignk = max(max(128 // DataTypeSize[self.operation.A.element],
|
| 1004 |
+
128 // DataTypeSize[self.operation.B.element]), 1)
|
| 1005 |
+
|
| 1006 |
+
gemm_k_size = (((arguments.problem_size.k + arguments.batch_count - 1) //
|
| 1007 |
+
arguments.batch_count + alignk - 1) // alignk) * alignk
|
| 1008 |
+
|
| 1009 |
+
if gemm_k_size:
|
| 1010 |
+
grid_z = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size
|
| 1011 |
+
grid = GemmCoord(grid.m, grid.n, grid_z).ctype
|
| 1012 |
+
|
| 1013 |
+
arguments.grid_tiled_shape = dim3_(grid.m, grid.n, grid.k)
|
| 1014 |
+
grid = self.get_grid_shape(grid)
|
| 1015 |
+
arguments.gemm_k_size = gemm_k_size
|
| 1016 |
+
return LaunchConfiguration(
|
| 1017 |
+
[grid.x, grid.y, grid.z],
|
| 1018 |
+
[self.threads, 1, 1],
|
| 1019 |
+
self.shared_memory_capacity)
|
| 1020 |
+
|
| 1021 |
+
def get_device_workspace_size(self, arguments: GemmArguments):
|
| 1022 |
+
workspace_bytes = 0
|
| 1023 |
+
if arguments.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 1024 |
+
workspace_bytes = (DataTypeSize[arguments.operation.C.element]
|
| 1025 |
+
* arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8)
|
| 1026 |
+
elif (arguments.gemm_mode == GemmUniversalMode.Gemm and
|
| 1027 |
+
arguments.split_k_slices > 1):
|
| 1028 |
+
workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y
|
| 1029 |
+
|
| 1030 |
+
return workspace_bytes
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
class GemmRTUniversalStreamK(GemmRTUniversal):
|
| 1034 |
+
"""
|
| 1035 |
+
Manages the CUTLASS runtime components for 2.x stream K kernels
|
| 1036 |
+
"""
|
| 1037 |
+
|
| 1038 |
+
HostTemplate = r"""
|
| 1039 |
+
extern "C" {
|
| 1040 |
+
// Get the size of params in bytes
|
| 1041 |
+
int ${operation_name}_get_param_size(){
|
| 1042 |
+
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 1043 |
+
}
|
| 1044 |
+
|
| 1045 |
+
// Get the size of dynamic shared memory in bytes
|
| 1046 |
+
int ${operation_name}_shared_memory_size() {
|
| 1047 |
+
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 1048 |
+
}
|
| 1049 |
+
|
| 1050 |
+
using GemmType = ${operation_name}_base;
|
| 1051 |
+
|
| 1052 |
+
// Get the params as byte array
|
| 1053 |
+
char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace,
|
| 1054 |
+
int sm_count, int occupancy) {
|
| 1055 |
+
GemmType::Params* params;
|
| 1056 |
+
params = new GemmType::Params(*argument, sm_count, occupancy);
|
| 1057 |
+
|
| 1058 |
+
params->init_workspace(workspace);
|
| 1059 |
+
|
| 1060 |
+
char *bytes = ((char*)(params));
|
| 1061 |
+
char *output = new char[sizeof(GemmType::Params)];
|
| 1062 |
+
for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
|
| 1063 |
+
output[i] = bytes[i];
|
| 1064 |
+
|
| 1065 |
+
return output;
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
|
| 1069 |
+
typename GemmType::Params params(*args, device_sms, sm_occupancy);
|
| 1070 |
+
return params.get_grid_dims();
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
|
| 1074 |
+
typename GemmType::Params params(*args, device_sms, sm_occupancy);
|
| 1075 |
+
return params.get_workspace_size();
|
| 1076 |
+
}
|
| 1077 |
+
}
|
| 1078 |
+
"""
|
| 1079 |
+
|
| 1080 |
+
def __init__(self, operation: "GemmOperation"):
|
| 1081 |
+
super(GemmRTUniversalStreamK, self).__init__(operation)
|
| 1082 |
+
self.extra_funcs = {
|
| 1083 |
+
"get_grid_shape": GemmCoord_,
|
| 1084 |
+
"get_kernel_workspace_size": ctypes.c_uint64,
|
| 1085 |
+
}
|
| 1086 |
+
self._occupancy = None
|
| 1087 |
+
self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor)
|
| 1088 |
+
|
| 1089 |
+
@property
|
| 1090 |
+
def occupancy(self):
|
| 1091 |
+
if self._occupancy is None:
|
| 1092 |
+
err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
|
| 1093 |
+
self.kernel, self.threads, self.shared_memory_capacity,
|
| 1094 |
+
cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE)
|
| 1095 |
+
|
| 1096 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 1097 |
+
raise RuntimeError(
|
| 1098 |
+
"CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: "
|
| 1099 |
+
f"{cuda.cuGetErrorString(err)[1]}")
|
| 1100 |
+
return self._occupancy
|
| 1101 |
+
|
| 1102 |
+
def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int):
|
| 1103 |
+
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy)
|
| 1104 |
+
|
| 1105 |
+
|
| 1106 |
+
################################################################################
|
| 1107 |
+
# Runtime module for GEMM Universal within CUTLASS 3
|
| 1108 |
+
################################################################################
|
| 1109 |
+
|
| 1110 |
+
|
| 1111 |
+
class GemmRTUniversal3x(GemmRTUniversal):
|
| 1112 |
+
"""
|
| 1113 |
+
Manages the CUTLASS runtime components for 3.x kernels
|
| 1114 |
+
"""
|
| 1115 |
+
|
| 1116 |
+
KernelTemplate = r"""
|
| 1117 |
+
|
| 1118 |
+
using Operator = ${operation_name}${operation_suffix};
|
| 1119 |
+
extern "C"
|
| 1120 |
+
__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
|
| 1121 |
+
void ${operation_name}(__grid_constant__ typename Operator::Params const params) {
|
| 1122 |
+
// Dynamic shared memory base pointer
|
| 1123 |
+
extern __shared__ char smem[];
|
| 1124 |
+
|
| 1125 |
+
// Declare pointer to dynamic shared memory.
|
| 1126 |
+
Operator op;
|
| 1127 |
+
op(params, smem);
|
| 1128 |
+
}
|
| 1129 |
+
"""
|
| 1130 |
+
HostTemplate = r"""
|
| 1131 |
+
extern "C" {
|
| 1132 |
+
// Get the size of params in bytes
|
| 1133 |
+
int ${operation_name}_get_param_size(){
|
| 1134 |
+
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 1135 |
+
}
|
| 1136 |
+
|
| 1137 |
+
// Get the size of dynamic shared memory in bytes
|
| 1138 |
+
int ${operation_name}_shared_memory_size() {
|
| 1139 |
+
return ${operation_name}${operation_suffix}::SharedStorageSize;
|
| 1140 |
+
}
|
| 1141 |
+
|
| 1142 |
+
using GemmType = ${operation_name}_base;
|
| 1143 |
+
|
| 1144 |
+
bool ${operation_name}_uses_default_epilogue() {
|
| 1145 |
+
return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
|
| 1146 |
+
}
|
| 1147 |
+
|
| 1148 |
+
// Get the workspace size
|
| 1149 |
+
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
|
| 1150 |
+
return GemmType::get_workspace_size(*argument);
|
| 1151 |
+
}
|
| 1152 |
+
|
| 1153 |
+
// Get the params as byte array
|
| 1154 |
+
char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){
|
| 1155 |
+
GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace);
|
| 1156 |
+
char *bytes = ((char*)(¶ms));
|
| 1157 |
+
char *output = new char[sizeof(GemmType::Params)];
|
| 1158 |
+
for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
|
| 1159 |
+
output[i] = bytes[i];
|
| 1160 |
+
|
| 1161 |
+
return output;
|
| 1162 |
+
}
|
| 1163 |
+
|
| 1164 |
+
// Get the total number of blocks for a persistent kernel
|
| 1165 |
+
uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) {
|
| 1166 |
+
auto problem_shape_MNKL = append<4>(problem, Int<1>{});
|
| 1167 |
+
auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
|
| 1168 |
+
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl(
|
| 1169 |
+
problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{});
|
| 1170 |
+
return problem_blocks_m * problem_blocks_n * problem_blocks_l;
|
| 1171 |
+
}
|
| 1172 |
+
|
| 1173 |
+
// Get the grid shape
|
| 1174 |
+
dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) {
|
| 1175 |
+
auto tmp_params = GemmType::to_underlying_arguments(*args, workspace);
|
| 1176 |
+
return GemmType::get_grid_shape(tmp_params);
|
| 1177 |
+
}
|
| 1178 |
+
|
| 1179 |
+
// Get the block shape
|
| 1180 |
+
dim3 ${operation_name}_get_block_shape() {
|
| 1181 |
+
return GemmType::get_block_shape();
|
| 1182 |
+
}
|
| 1183 |
+
}
|
| 1184 |
+
"""
|
| 1185 |
+
|
| 1186 |
+
def __init__(self, operation):
|
| 1187 |
+
super(GemmRTUniversal3x, self).__init__(operation)
|
| 1188 |
+
self.extra_funcs = {
|
| 1189 |
+
"get_grid_shape": dim3_,
|
| 1190 |
+
"get_block_shape": dim3_,
|
| 1191 |
+
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
|
| 1192 |
+
"get_kernel_workspace_size": ctypes.c_uint64,
|
| 1193 |
+
"uses_default_epilogue": ctypes.c_bool,
|
| 1194 |
+
}
|
| 1195 |
+
self.emitter = EmitGemmUniversalInstance3x("_type")
|
| 1196 |
+
|
| 1197 |
+
def get_device_workspace_size(self, arguments: GemmArguments3x):
|
| 1198 |
+
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
|
| 1199 |
+
|
| 1200 |
+
|
| 1201 |
+
class EmitGemmUniversalInstance3x:
|
| 1202 |
+
"""Responsible for emitting a CUTLASS 3 template definition"""
|
| 1203 |
+
|
| 1204 |
+
def __init__(self, operation_suffix=""):
|
| 1205 |
+
self.operation_suffix = operation_suffix
|
| 1206 |
+
self.includes = [
|
| 1207 |
+
"cutlass/cutlass.h",
|
| 1208 |
+
"cute/tensor.hpp",
|
| 1209 |
+
"cute/atom/mma_atom.hpp",
|
| 1210 |
+
"cutlass/numeric_types.h",
|
| 1211 |
+
"cutlass/gemm/collective/collective_builder.hpp",
|
| 1212 |
+
"cutlass/gemm/kernel/sm90_tile_scheduler.hpp",
|
| 1213 |
+
"cutlass/gemm/kernel/gemm_universal.hpp",
|
| 1214 |
+
"cutlass/epilogue/collective/collective_builder.hpp",
|
| 1215 |
+
"cutlass/epilogue/collective/default_epilogue.hpp",
|
| 1216 |
+
"cutlass/epilogue/thread/linear_combination.h"
|
| 1217 |
+
]
|
| 1218 |
+
self.gemm_template_kernel = """
|
| 1219 |
+
using namespace cute;
|
| 1220 |
+
|
| 1221 |
+
using CollectiveEpilogue =
|
| 1222 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 1223 |
+
${arch}, ${opcode_class},
|
| 1224 |
+
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1225 |
+
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1226 |
+
cutlass::epilogue::collective::EpilogueTileAuto,
|
| 1227 |
+
${element_accumulator}, ${element_epilogue},
|
| 1228 |
+
${element_c}, ${layout_c}, ${align_c},
|
| 1229 |
+
${element_d}, ${layout_d}, ${align_d},
|
| 1230 |
+
${epilogue_schedule}
|
| 1231 |
+
>::CollectiveOp;
|
| 1232 |
+
|
| 1233 |
+
using CollectiveMainloop =
|
| 1234 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
| 1235 |
+
${arch}, ${opcode_class},
|
| 1236 |
+
${element_a}, ${layout_a}, ${align_a},
|
| 1237 |
+
${element_b}, ${layout_b}, ${align_b},
|
| 1238 |
+
${element_accumulator},
|
| 1239 |
+
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1240 |
+
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1241 |
+
${stage_count_type},
|
| 1242 |
+
${kernel_schedule}
|
| 1243 |
+
>::CollectiveOp;
|
| 1244 |
+
|
| 1245 |
+
// Gemm operator ${operation_name}
|
| 1246 |
+
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
| 1247 |
+
Shape<int,int,int,int>,
|
| 1248 |
+
CollectiveMainloop,
|
| 1249 |
+
CollectiveEpilogue,
|
| 1250 |
+
${tile_scheduler}
|
| 1251 |
+
>;
|
| 1252 |
+
|
| 1253 |
+
// Define named type
|
| 1254 |
+
struct ${operation_name}${operation_suffix} :
|
| 1255 |
+
public ${operation_name}_base { };
|
| 1256 |
+
"""
|
| 1257 |
+
self.gemm_template_kernel_visitor = """
|
| 1258 |
+
using namespace cute;
|
| 1259 |
+
|
| 1260 |
+
${callback_decl}
|
| 1261 |
+
|
| 1262 |
+
using CollectiveEpilogue =
|
| 1263 |
+
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 1264 |
+
${arch}, ${opcode_class},
|
| 1265 |
+
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1266 |
+
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1267 |
+
cutlass::epilogue::collective::EpilogueTileAuto,
|
| 1268 |
+
${element_accumulator}, ${element_epilogue},
|
| 1269 |
+
ElementC, StrideC, ${align_c},
|
| 1270 |
+
ElementD, StrideD, ${align_d},
|
| 1271 |
+
${epilogue_schedule},
|
| 1272 |
+
${callback_name}
|
| 1273 |
+
>::CollectiveOp;
|
| 1274 |
+
|
| 1275 |
+
using CollectiveMainloop =
|
| 1276 |
+
typename cutlass::gemm::collective::CollectiveBuilder<
|
| 1277 |
+
${arch}, ${opcode_class},
|
| 1278 |
+
${element_a}, ${layout_a}, ${align_a},
|
| 1279 |
+
${element_b}, ${layout_b}, ${align_b},
|
| 1280 |
+
${element_accumulator},
|
| 1281 |
+
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1282 |
+
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1283 |
+
${stage_count_type},
|
| 1284 |
+
${kernel_schedule}
|
| 1285 |
+
>::CollectiveOp;
|
| 1286 |
+
|
| 1287 |
+
// Gemm operator ${operation_name}
|
| 1288 |
+
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
| 1289 |
+
Shape<int,int,int,int>,
|
| 1290 |
+
CollectiveMainloop,
|
| 1291 |
+
CollectiveEpilogue,
|
| 1292 |
+
${tile_scheduler}
|
| 1293 |
+
>;
|
| 1294 |
+
|
| 1295 |
+
// Define named type
|
| 1296 |
+
struct ${operation_name}${operation_suffix} :
|
| 1297 |
+
public ${operation_name}_base { };
|
| 1298 |
+
"""
|
| 1299 |
+
|
| 1300 |
+
self.gemm_template_device = self.gemm_template_kernel + """
|
| 1301 |
+
|
| 1302 |
+
// Define device-level operator
|
| 1303 |
+
using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>;
|
| 1304 |
+
"""
|
| 1305 |
+
|
| 1306 |
+
def emit(self, operation):
|
| 1307 |
+
# Support built-in epilogue functors or user-defined functions
|
| 1308 |
+
|
| 1309 |
+
if operation.tile_description.stages is None or operation.tile_description.stages == 0:
|
| 1310 |
+
stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>"
|
| 1311 |
+
else:
|
| 1312 |
+
stage_count_type = "_" + str(operation.tile_description.stages)
|
| 1313 |
+
|
| 1314 |
+
if operation.emission_type == EmissionType.Kernel:
|
| 1315 |
+
gemm_template = self.gemm_template_kernel
|
| 1316 |
+
else:
|
| 1317 |
+
gemm_template = self.gemm_template_device
|
| 1318 |
+
|
| 1319 |
+
kschedule = KernelScheduleType.ScheduleAuto
|
| 1320 |
+
eschedule = EpilogueScheduleType.ScheduleAuto
|
| 1321 |
+
tschedule = TileSchedulerType.Default
|
| 1322 |
+
if operation.tile_description.kernel_schedule is not None:
|
| 1323 |
+
kschedule = operation.tile_description.kernel_schedule
|
| 1324 |
+
if operation.tile_description.epilogue_schedule is not None:
|
| 1325 |
+
eschedule = operation.tile_description.epilogue_schedule
|
| 1326 |
+
if operation.tile_description.tile_scheduler is not None:
|
| 1327 |
+
tschedule = operation.tile_description.tile_scheduler
|
| 1328 |
+
|
| 1329 |
+
emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape
|
| 1330 |
+
|
| 1331 |
+
values = {
|
| 1332 |
+
"operation_name": operation.procedural_name(),
|
| 1333 |
+
"operation_suffix": self.operation_suffix,
|
| 1334 |
+
"element_a": DataTypeTag[operation.A.element],
|
| 1335 |
+
"layout_a": LayoutTag[operation.A.layout],
|
| 1336 |
+
"element_b": DataTypeTag[operation.B.element],
|
| 1337 |
+
"layout_b": LayoutTag[operation.B.layout],
|
| 1338 |
+
"element_c": DataTypeTag[operation.C.element],
|
| 1339 |
+
"layout_c": LayoutTag[operation.C.layout],
|
| 1340 |
+
"element_d": DataTypeTag[operation.epilogue_functor.element_output],
|
| 1341 |
+
"layout_d": LayoutTag[operation.C.layout],
|
| 1342 |
+
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 1343 |
+
"element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue],
|
| 1344 |
+
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 1345 |
+
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
| 1346 |
+
"threadblock_shape_m": str(emit_tile_m),
|
| 1347 |
+
"threadblock_shape_n": str(emit_tile_n),
|
| 1348 |
+
"threadblock_shape_k": str(emit_tile_k),
|
| 1349 |
+
"cluster_m": str(operation.tile_description.cluster_shape[0]),
|
| 1350 |
+
"cluster_n": str(operation.tile_description.cluster_shape[1]),
|
| 1351 |
+
"cluster_k": str(operation.tile_description.cluster_shape[2]),
|
| 1352 |
+
"align_a": str(operation.A.alignment),
|
| 1353 |
+
"align_b": str(operation.B.alignment),
|
| 1354 |
+
"align_c": str(operation.C.alignment),
|
| 1355 |
+
"align_d": str(operation.C.alignment),
|
| 1356 |
+
"stage_count_type": stage_count_type,
|
| 1357 |
+
"kernel_schedule": KernelScheduleTag[kschedule],
|
| 1358 |
+
"epilogue_schedule": EpilogueScheduleTag[eschedule],
|
| 1359 |
+
"tile_scheduler": TileSchedulerTag[tschedule]
|
| 1360 |
+
}
|
| 1361 |
+
if hasattr(operation.epilogue_functor, "visitor"):
|
| 1362 |
+
callback_name, callback_decl = operation.epilogue_functor.emit(operation)
|
| 1363 |
+
values["callback_name"] = callback_name
|
| 1364 |
+
values["callback_decl"] = callback_decl
|
| 1365 |
+
return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
|
| 1366 |
+
|
| 1367 |
+
else:
|
| 1368 |
+
values["epilogue_functor"] = operation.epilogue_functor.emit()
|
| 1369 |
+
return SubstituteTemplate(gemm_template, values)
|
| 1370 |
+
|
| 1371 |
+
|
| 1372 |
+
###################################################################################################
|
| 1373 |
+
# Runtime module for GEMM Grouped
|
| 1374 |
+
###################################################################################################
|
| 1375 |
+
|
| 1376 |
+
|
| 1377 |
+
class GemmRTGrouped(GemmRTbase):
|
| 1378 |
+
"""
|
| 1379 |
+
GemmRTGrouped manages the CUTLASS runtime components
|
| 1380 |
+
"""
|
| 1381 |
+
|
| 1382 |
+
KernelTemplate = r"""
|
| 1383 |
+
extern "C"
|
| 1384 |
+
__global__ void
|
| 1385 |
+
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
| 1386 |
+
|
| 1387 |
+
// Dynamic shared memory base pointer
|
| 1388 |
+
extern __shared__ int SharedStorageBase[];
|
| 1389 |
+
|
| 1390 |
+
// Declare pointer to dynamic shared memory.
|
| 1391 |
+
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
| 1392 |
+
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
| 1393 |
+
|
| 1394 |
+
${operation_name}${operation_suffix} op;
|
| 1395 |
+
|
| 1396 |
+
op(params, *shared_storage);
|
| 1397 |
+
}
|
| 1398 |
+
"""
|
| 1399 |
+
|
| 1400 |
+
HostTemplate = r"""
|
| 1401 |
+
extern "C" {
|
| 1402 |
+
|
| 1403 |
+
// precompute scheduling information
|
| 1404 |
+
char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) {
|
| 1405 |
+
char* host_workspace = new char[workspace_bytes];
|
| 1406 |
+
${operation_name}_base::ProblemVisitor::host_precompute(
|
| 1407 |
+
args.host_problem_sizes,
|
| 1408 |
+
args.problem_count,
|
| 1409 |
+
args.threadblock_count,
|
| 1410 |
+
(void*)host_workspace
|
| 1411 |
+
);
|
| 1412 |
+
return host_workspace;
|
| 1413 |
+
}
|
| 1414 |
+
|
| 1415 |
+
// Get the size of params in bytes
|
| 1416 |
+
int ${operation_name}_get_param_size(){
|
| 1417 |
+
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 1418 |
+
}
|
| 1419 |
+
|
| 1420 |
+
// Get the size of dynamic shared memory in bytes
|
| 1421 |
+
int ${operation_name}_shared_memory_size() {
|
| 1422 |
+
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 1423 |
+
}
|
| 1424 |
+
|
| 1425 |
+
// Get the params as byte array
|
| 1426 |
+
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){
|
| 1427 |
+
${operation_name}_base::Params* params;
|
| 1428 |
+
params = new ${operation_name}_base::Params(*argument, workspace, tile_count);
|
| 1429 |
+
|
| 1430 |
+
char *bytes = ((char*)(params));
|
| 1431 |
+
char *output = new char[sizeof(${operation_name}_base::Params)];
|
| 1432 |
+
for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
|
| 1433 |
+
output[i] = bytes[i];
|
| 1434 |
+
|
| 1435 |
+
return output;
|
| 1436 |
+
}
|
| 1437 |
+
|
| 1438 |
+
cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
|
| 1439 |
+
cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
|
| 1440 |
+
return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
|
| 1441 |
+
problem_size, tile_size, split_k_slices);
|
| 1442 |
+
}
|
| 1443 |
+
|
| 1444 |
+
dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
|
| 1445 |
+
return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
|
| 1446 |
+
}
|
| 1447 |
+
}
|
| 1448 |
+
"""
|
| 1449 |
+
|
| 1450 |
+
def __init__(self, operation: "GemmOperation"):
|
| 1451 |
+
super(GemmRTGrouped, self).__init__(operation)
|
| 1452 |
+
self.extra_funcs = {
|
| 1453 |
+
"precompute": None,
|
| 1454 |
+
"get_tiled_shape": GemmCoord_,
|
| 1455 |
+
"get_grid_shape": dim3_,
|
| 1456 |
+
}
|
| 1457 |
+
self.emitter = EmitGemmGroupedInstance("_type")
|
| 1458 |
+
self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
|
| 1459 |
+
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]
|
| 1460 |
+
|
| 1461 |
+
def host_precompute(self, arguments, workspace_bytes):
|
| 1462 |
+
self.precompute.argtype = [
|
| 1463 |
+
self.argtype[0], ctypes.c_int, ctypes.c_longlong]
|
| 1464 |
+
self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes)
|
| 1465 |
+
|
| 1466 |
+
problem_info = self.precompute(
|
| 1467 |
+
ctypes.byref(arguments.arguments),
|
| 1468 |
+
arguments.total_tiles,
|
| 1469 |
+
workspace_bytes)
|
| 1470 |
+
problem_info_array = bytearray(problem_info.contents)
|
| 1471 |
+
|
| 1472 |
+
# copy to device memory
|
| 1473 |
+
return todevice(problem_info_array).ptr
|
| 1474 |
+
|
| 1475 |
+
def plan(self, arguments):
|
| 1476 |
+
return LaunchConfiguration(
|
| 1477 |
+
[arguments.total_tiles, 1, 1],
|
| 1478 |
+
[self.threads, 1, 1],
|
| 1479 |
+
self.shared_memory_capacity,
|
| 1480 |
+
)
|
| 1481 |
+
|
| 1482 |
+
def get_workspace_size(self, arguments):
|
| 1483 |
+
if self.operation.precompute_mode == SchedulerMode.Device:
|
| 1484 |
+
return 0
|
| 1485 |
+
elif self.operation.precompute_mode == SchedulerMode.Host:
|
| 1486 |
+
total_tiles = arguments.total_tiles
|
| 1487 |
+
entries_per_block = 1
|
| 1488 |
+
return 8 * entries_per_block * total_tiles # three int32_t
|
| 1489 |
+
|
| 1490 |
+
|
| 1491 |
+
################################################################################
|
| 1492 |
+
# Runtime module for GEMM and grouped GEMM
|
| 1493 |
+
################################################################################
|
| 1494 |
+
|
| 1495 |
+
|
| 1496 |
+
class GemmOperationBase:
|
| 1497 |
+
"""
|
| 1498 |
+
CUTLASS GEMM operation
|
| 1499 |
+
"""
|
| 1500 |
+
|
| 1501 |
+
def __init__(
|
| 1502 |
+
self, gemm_kind, arch, tile_description: TileDescription,
|
| 1503 |
+
A: TensorDescription, B: TensorDescription, C: TensorDescription,
|
| 1504 |
+
epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1,
|
| 1505 |
+
api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs):
|
| 1506 |
+
self.operation_kind: OperationKind = OperationKind.Gemm
|
| 1507 |
+
self.arch: int = arch
|
| 1508 |
+
self.tile_description: TileDescription = tile_description
|
| 1509 |
+
self.gemm_kind: GemmKind = gemm_kind
|
| 1510 |
+
|
| 1511 |
+
self.api = api
|
| 1512 |
+
self.prefix = "3x" if self.api == ApiVersion.v3x else ""
|
| 1513 |
+
self.emission_type = emission_type
|
| 1514 |
+
|
| 1515 |
+
# Optionally swap the TensorDescriptions for operands A and B and transpose their
|
| 1516 |
+
# layouts. This is needed to mimic the transpose performed by device::GemmUniversal.
|
| 1517 |
+
# The code below uses deep copy to avoid overwritting the original TensorDescription
|
| 1518 |
+
self.switched = (self.api != ApiVersion.v3x and
|
| 1519 |
+
self.emission_type == EmissionType.Kernel and
|
| 1520 |
+
C.layout == LayoutType.ColumnMajor)
|
| 1521 |
+
|
| 1522 |
+
self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched)
|
| 1523 |
+
|
| 1524 |
+
self.epilogue_functor = epilogue_functor
|
| 1525 |
+
self.swizzling_functor = swizzling_functor
|
| 1526 |
+
|
| 1527 |
+
if "direct_store" in kwargs:
|
| 1528 |
+
self.direct_store = kwargs["direct_store"]
|
| 1529 |
+
else:
|
| 1530 |
+
self.direct_store = False
|
| 1531 |
+
|
| 1532 |
+
@staticmethod
|
| 1533 |
+
def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool):
|
| 1534 |
+
"""
|
| 1535 |
+
Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set,
|
| 1536 |
+
A and B are swapped, and the layout of A, B, and C are transposed.
|
| 1537 |
+
|
| 1538 |
+
:param A: description of operand A
|
| 1539 |
+
:type A: TensorDescription
|
| 1540 |
+
:param B: description of operand B
|
| 1541 |
+
:type B: TensorDescription
|
| 1542 |
+
:param C: description of operand C
|
| 1543 |
+
:type C: TensorDescription
|
| 1544 |
+
|
| 1545 |
+
:return: descriptions of operands A, B, and C
|
| 1546 |
+
:rtype: tuple[TileDescription]
|
| 1547 |
+
"""
|
| 1548 |
+
if swap:
|
| 1549 |
+
A_out = copy.deepcopy(B)
|
| 1550 |
+
B_out = copy.deepcopy(A)
|
| 1551 |
+
C_out = copy.deepcopy(C)
|
| 1552 |
+
A_out.layout = transpose_layout(A_out.layout)
|
| 1553 |
+
B_out.layout = transpose_layout(B_out.layout)
|
| 1554 |
+
C_out.layout = transpose_layout(C_out.layout)
|
| 1555 |
+
else:
|
| 1556 |
+
A_out = copy.deepcopy(A)
|
| 1557 |
+
B_out = copy.deepcopy(B)
|
| 1558 |
+
C_out = copy.deepcopy(C)
|
| 1559 |
+
return A_out, B_out, C_out
|
| 1560 |
+
|
| 1561 |
+
def run(self, arguments: GemmArguments) -> cuda.CUresult:
|
| 1562 |
+
"""
|
| 1563 |
+
Configure and launch the cuda kernel with input arguments
|
| 1564 |
+
"""
|
| 1565 |
+
if self.emission_type == EmissionType.Device:
|
| 1566 |
+
raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"')
|
| 1567 |
+
|
| 1568 |
+
err = self.rt_module.run(
|
| 1569 |
+
arguments.host_workspace,
|
| 1570 |
+
arguments.device_workspace,
|
| 1571 |
+
arguments.launch_config,
|
| 1572 |
+
arguments.stream
|
| 1573 |
+
)
|
| 1574 |
+
|
| 1575 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 1576 |
+
raise RuntimeError("CUDA Error %s" % str(err))
|
| 1577 |
+
|
| 1578 |
+
return err
|
| 1579 |
+
|
| 1580 |
+
def is_complex(self):
|
| 1581 |
+
complex_operators = [
|
| 1582 |
+
MathOperation.multiply_add_complex,
|
| 1583 |
+
MathOperation.multiply_add_complex_gaussian,
|
| 1584 |
+
MathOperation.multiply_add_complex_fast_f32,
|
| 1585 |
+
]
|
| 1586 |
+
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 1587 |
+
|
| 1588 |
+
def is_planar_complex(self):
|
| 1589 |
+
return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
|
| 1590 |
+
|
| 1591 |
+
def accumulator_type(self):
|
| 1592 |
+
accum = self.tile_description.math_instruction.element_accumulator
|
| 1593 |
+
|
| 1594 |
+
if self.is_complex():
|
| 1595 |
+
return get_complex_from_real(accum)
|
| 1596 |
+
|
| 1597 |
+
return accum
|
| 1598 |
+
|
| 1599 |
+
def short_math_name(self):
|
| 1600 |
+
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 1601 |
+
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 1602 |
+
return ShortDataTypeNames[self.accumulator_type()]
|
| 1603 |
+
|
| 1604 |
+
def core_name(self):
|
| 1605 |
+
"""The basic operation kind is prefixed with a letter indicating the accumulation type."""
|
| 1606 |
+
|
| 1607 |
+
inst_shape = ""
|
| 1608 |
+
inst_operation = ""
|
| 1609 |
+
intermediate_type = ""
|
| 1610 |
+
|
| 1611 |
+
math_operations_map = {
|
| 1612 |
+
MathOperation.xor_popc: "xor",
|
| 1613 |
+
}
|
| 1614 |
+
|
| 1615 |
+
if (self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or
|
| 1616 |
+
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp):
|
| 1617 |
+
math_op = self.tile_description.math_instruction.math_operation
|
| 1618 |
+
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ""
|
| 1619 |
+
|
| 1620 |
+
if self.tile_description.math_instruction.instruction_shape is not None:
|
| 1621 |
+
if self.api == ApiVersion.v3x and self.arch >= 90:
|
| 1622 |
+
inst_shape = "%dx%dx%d" % tuple(
|
| 1623 |
+
self.tile_description.math_instruction.instruction_shape)
|
| 1624 |
+
else:
|
| 1625 |
+
inst_shape = "%d%d%d" % tuple(
|
| 1626 |
+
self.tile_description.math_instruction.instruction_shape)
|
| 1627 |
+
else:
|
| 1628 |
+
inst_shape = "Default"
|
| 1629 |
+
inst_shape += math_op_string
|
| 1630 |
+
|
| 1631 |
+
if (self.tile_description.math_instruction.element_a != self.A.element and
|
| 1632 |
+
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator):
|
| 1633 |
+
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 1634 |
+
|
| 1635 |
+
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
| 1636 |
+
|
| 1637 |
+
def extended_name(self):
|
| 1638 |
+
"""Append data types if they differ from compute type."""
|
| 1639 |
+
if self.is_complex():
|
| 1640 |
+
extended_name = "${core_name}"
|
| 1641 |
+
else:
|
| 1642 |
+
if (self.C.element != self.tile_description.math_instruction.element_accumulator and
|
| 1643 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator):
|
| 1644 |
+
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 1645 |
+
elif (self.C.element == self.tile_description.math_instruction.element_accumulator and
|
| 1646 |
+
self.A.element != self.tile_description.math_instruction.element_accumulator):
|
| 1647 |
+
extended_name = "${core_name}_${element_a}"
|
| 1648 |
+
else:
|
| 1649 |
+
extended_name = "${core_name}"
|
| 1650 |
+
|
| 1651 |
+
extended_name = SubstituteTemplate(extended_name, {
|
| 1652 |
+
"element_a": DataTypeNames[self.A.element],
|
| 1653 |
+
"element_c": DataTypeNames[self.C.element],
|
| 1654 |
+
"core_name": self.core_name(),
|
| 1655 |
+
})
|
| 1656 |
+
|
| 1657 |
+
return extended_name
|
| 1658 |
+
|
| 1659 |
+
def extended_name_3x(self):
|
| 1660 |
+
"""Generates a string representing the MMA atom. Assumes accumulator type is C type."""
|
| 1661 |
+
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
| 1662 |
+
element_a=DataTypeNames[self.A.element],
|
| 1663 |
+
element_b=DataTypeNames[self.B.element],
|
| 1664 |
+
element_acc=DataTypeNames[self.accumulator_type()],
|
| 1665 |
+
element_c=DataTypeNames[self.C.element],
|
| 1666 |
+
element_d=DataTypeNames[self.epilogue_functor.element_output],
|
| 1667 |
+
core_name=self.core_name())
|
| 1668 |
+
return extended_name
|
| 1669 |
+
|
| 1670 |
+
def layout_name(self):
|
| 1671 |
+
if self.is_complex() or self.is_planar_complex():
|
| 1672 |
+
return "%s%s" % (
|
| 1673 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 1674 |
+
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
|
| 1675 |
+
)
|
| 1676 |
+
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
|
| 1677 |
+
|
| 1678 |
+
# Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
|
| 1679 |
+
def layout_name_3x(self):
|
| 1680 |
+
if self.is_complex() or self.is_planar_complex():
|
| 1681 |
+
return "{}{}{}".format(
|
| 1682 |
+
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 1683 |
+
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
|
| 1684 |
+
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
|
| 1685 |
+
else:
|
| 1686 |
+
return "{}{}{}".format(
|
| 1687 |
+
ShortLayoutTypeNames[self.A.layout],
|
| 1688 |
+
ShortLayoutTypeNames[self.B.layout],
|
| 1689 |
+
ShortLayoutTypeNames[self.C.layout])
|
| 1690 |
+
|
| 1691 |
+
# Generates a short string representing underlying kernel schedule type
|
| 1692 |
+
def kernel_schedule_name_3x(self):
|
| 1693 |
+
if self.tile_description.kernel_schedule is None:
|
| 1694 |
+
return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto]
|
| 1695 |
+
else:
|
| 1696 |
+
return KernelScheduleSuffixes[self.tile_description.kernel_schedule]
|
| 1697 |
+
|
| 1698 |
+
# Generates a short string representing underlying epilogue schedule type
|
| 1699 |
+
def epilogue_schedule_name_3x(self):
|
| 1700 |
+
if self.tile_description.epilogue_schedule is None:
|
| 1701 |
+
return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto]
|
| 1702 |
+
else:
|
| 1703 |
+
return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule]
|
| 1704 |
+
|
| 1705 |
+
def procedural_name(self):
|
| 1706 |
+
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
| 1707 |
+
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 1708 |
+
if self.api == ApiVersion.v3x and self.arch >= 90:
|
| 1709 |
+
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
|
| 1710 |
+
return kernel_name_template.format(
|
| 1711 |
+
p=self.prefix,
|
| 1712 |
+
ar=self.arch,
|
| 1713 |
+
op=opcode_class_name,
|
| 1714 |
+
ex=self.extended_name_3x(),
|
| 1715 |
+
tbm=self.tile_description.threadblock_shape[0],
|
| 1716 |
+
tbn=self.tile_description.threadblock_shape[1],
|
| 1717 |
+
tbk=self.tile_description.threadblock_shape[2],
|
| 1718 |
+
cm=self.tile_description.cluster_shape[0],
|
| 1719 |
+
cn=self.tile_description.cluster_shape[1],
|
| 1720 |
+
ck=self.tile_description.cluster_shape[2],
|
| 1721 |
+
l=self.tile_description.stages,
|
| 1722 |
+
s=self.layout_name_3x(),
|
| 1723 |
+
al=str(self.A.alignment),
|
| 1724 |
+
k=self.kernel_schedule_name_3x(),
|
| 1725 |
+
e=self.epilogue_schedule_name_3x()
|
| 1726 |
+
)
|
| 1727 |
+
else:
|
| 1728 |
+
threadblock = self.tile_description.procedural_name_2x()
|
| 1729 |
+
return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
|
| 1730 |
+
p=self.prefix,
|
| 1731 |
+
op=opcode_class_name,
|
| 1732 |
+
ex=self.extended_name(),
|
| 1733 |
+
tb=threadblock,
|
| 1734 |
+
l=self.layout_name(),
|
| 1735 |
+
a=str(self.A.alignment)
|
| 1736 |
+
)
|
| 1737 |
+
|
| 1738 |
+
def configuration_name(self):
|
| 1739 |
+
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
| 1740 |
+
return self.procedural_name()
|
| 1741 |
+
|
| 1742 |
+
|
| 1743 |
+
class GemmOperationUniversal(GemmOperationBase):
|
| 1744 |
+
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
|
| 1745 |
+
epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
|
| 1746 |
+
api = api_version(arch, tile_description.math_instruction.opcode_class, A.element)
|
| 1747 |
+
super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description,
|
| 1748 |
+
A, B, C, epilogue_functor, swizzling_functor,
|
| 1749 |
+
api=api, **kwargs, )
|
| 1750 |
+
if api == ApiVersion.v3x:
|
| 1751 |
+
if swizzling_functor == SwizzlingFunctor.StreamK:
|
| 1752 |
+
raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels")
|
| 1753 |
+
self.rt_module = GemmRTUniversal3x(self)
|
| 1754 |
+
else:
|
| 1755 |
+
if swizzling_functor == SwizzlingFunctor.StreamK:
|
| 1756 |
+
self.rt_module = GemmRTUniversalStreamK(self)
|
| 1757 |
+
else:
|
| 1758 |
+
self.rt_module = GemmRTUniversal(self)
|
| 1759 |
+
self.argument_type = self.rt_module.argument_type
|
| 1760 |
+
self.epilogue_type = self.rt_module.epilogue_type
|
| 1761 |
+
|
| 1762 |
+
def device_op(self):
|
| 1763 |
+
"""
|
| 1764 |
+
Returns a new GemmOperationUniversal object that is constructed with emission type
|
| 1765 |
+
``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
|
| 1766 |
+
any swappng performed by the kernel-emitted operation is reversed.
|
| 1767 |
+
|
| 1768 |
+
:return: operation ready for device-level code emission
|
| 1769 |
+
:rtype: GemmUniversalOperation
|
| 1770 |
+
"""
|
| 1771 |
+
A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
|
| 1772 |
+
return GemmOperationUniversal(self.arch, self.tile_description, A, B, C,
|
| 1773 |
+
self.epilogue_functor, self.swizzling_functor,
|
| 1774 |
+
emission_type=EmissionType.Device, direct_store=self.direct_store)
|
| 1775 |
+
|
| 1776 |
+
|
| 1777 |
+
class GemmOperationGrouped(GemmOperationBase):
|
| 1778 |
+
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
|
| 1779 |
+
epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
|
| 1780 |
+
super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
|
| 1781 |
+
A, B, C, epilogue_functor, swizzling_functor, **kwargs)
|
| 1782 |
+
assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'."
|
| 1783 |
+
self.precompute_mode = kwargs["precompute_mode"]
|
| 1784 |
+
self.rt_module = GemmRTGrouped(self)
|
| 1785 |
+
self.argument_type = self.rt_module.argument_type
|
| 1786 |
+
self.epilogue_type = self.rt_module.epilogue_type
|
| 1787 |
+
|
| 1788 |
+
def device_op(self):
|
| 1789 |
+
"""
|
| 1790 |
+
Returns a new GemmOperationGrouped object that is constructed with emission type
|
| 1791 |
+
``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
|
| 1792 |
+
any swappng performed by the kernel-emitted operation is reversed.
|
| 1793 |
+
|
| 1794 |
+
:return: operation ready for device-level code emission
|
| 1795 |
+
:rtype: GemmOperationGrouped
|
| 1796 |
+
"""
|
| 1797 |
+
A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
|
| 1798 |
+
return GemmOperationGrouped(
|
| 1799 |
+
self.arch, self.tile_description, A, B, C, self.epilogue_functor,
|
| 1800 |
+
self.swizzling_functor, emission_type=EmissionType.Device,
|
| 1801 |
+
direct_store=self.direct_store, precompute_mode=self.precompute_mode, )
|
| 1802 |
+
|
| 1803 |
+
|
| 1804 |
+
###################################################################################################
|
| 1805 |
+
#
|
| 1806 |
+
# Emits single instances of a CUTLASS device-wide operator
|
| 1807 |
+
#
|
| 1808 |
+
###################################################################################################
|
| 1809 |
+
|
| 1810 |
+
|
| 1811 |
+
class EmitGemmUniversalInstance:
|
| 1812 |
+
"""Responsible for emitting a CUTLASS template definition"""
|
| 1813 |
+
|
| 1814 |
+
def __init__(
|
| 1815 |
+
self,
|
| 1816 |
+
operation_suffix="",
|
| 1817 |
+
direct_store=False
|
| 1818 |
+
):
|
| 1819 |
+
self.operation_suffix = operation_suffix
|
| 1820 |
+
self.direct_store = direct_store
|
| 1821 |
+
self.includes = [
|
| 1822 |
+
"cutlass/cutlass.h",
|
| 1823 |
+
"cutlass/gemm_coord.h",
|
| 1824 |
+
"cutlass/numeric_types.h",
|
| 1825 |
+
"cutlass/arch/arch.h",
|
| 1826 |
+
"cutlass/arch/mma.h",
|
| 1827 |
+
"cutlass/layout/matrix.h",
|
| 1828 |
+
"cutlass/gemm/device/gemm.h",
|
| 1829 |
+
"cutlass/gemm/device/gemm_universal_adapter.h",
|
| 1830 |
+
"cutlass/gemm/kernel/default_gemm_universal.h",
|
| 1831 |
+
]
|
| 1832 |
+
if self.direct_store:
|
| 1833 |
+
self.includes.append(
|
| 1834 |
+
"cutlass/epilogue/threadblock/default_epilogue_direct_store.h"
|
| 1835 |
+
)
|
| 1836 |
+
self.gemm_template_kernel = """
|
| 1837 |
+
// Gemm operator ${operation_name}
|
| 1838 |
+
using ${operation_name}_base =
|
| 1839 |
+
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
| 1840 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1841 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1842 |
+
${element_c}, ${layout_c},
|
| 1843 |
+
${element_accumulator},
|
| 1844 |
+
${opcode_class},
|
| 1845 |
+
${arch},
|
| 1846 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1847 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1848 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1849 |
+
${epilogue_functor},
|
| 1850 |
+
${swizzling_functor},
|
| 1851 |
+
${stages},
|
| 1852 |
+
${math_operation}
|
| 1853 |
+
>::GemmKernel;
|
| 1854 |
+
|
| 1855 |
+
// Define named type
|
| 1856 |
+
struct ${operation_name}${operation_suffix} :
|
| 1857 |
+
public ${operation_name}_base { };
|
| 1858 |
+
"""
|
| 1859 |
+
|
| 1860 |
+
self.gemm_template_device = """
|
| 1861 |
+
// Gemm operator ${operation_name}
|
| 1862 |
+
using DeviceKernel =
|
| 1863 |
+
typename cutlass::gemm::device::GemmUniversal<
|
| 1864 |
+
// Data type and layout of operand A
|
| 1865 |
+
${element_a}, ${layout_a},
|
| 1866 |
+
// Data type and layout of operand B
|
| 1867 |
+
${element_b}, ${layout_b},
|
| 1868 |
+
// Data type and layout of operand C
|
| 1869 |
+
${element_c}, ${layout_c},
|
| 1870 |
+
// Data type of accumulator
|
| 1871 |
+
${element_accumulator},
|
| 1872 |
+
// Class of operation
|
| 1873 |
+
${opcode_class},
|
| 1874 |
+
// Compute capability of the target kernel
|
| 1875 |
+
${arch},
|
| 1876 |
+
// Threadblock tile shape
|
| 1877 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1878 |
+
// Warp tile shape
|
| 1879 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1880 |
+
// Instruction shape
|
| 1881 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1882 |
+
// Epilogue functor
|
| 1883 |
+
${epilogue_functor},
|
| 1884 |
+
// Swizzling function
|
| 1885 |
+
${swizzling_functor},
|
| 1886 |
+
// Number of pipeline stages
|
| 1887 |
+
${stages},
|
| 1888 |
+
// Alignment of operands A and B
|
| 1889 |
+
${align_a}, ${align_b},
|
| 1890 |
+
// Type of math operation
|
| 1891 |
+
${math_operation},
|
| 1892 |
+
// Complex transform types of operands A and B
|
| 1893 |
+
${transform_a}, ${transform_b}
|
| 1894 |
+
>;
|
| 1895 |
+
"""
|
| 1896 |
+
self.gemm_template_direct_store = """
|
| 1897 |
+
// Gemm operator ${operation_name}
|
| 1898 |
+
using ${operation_name}_default =
|
| 1899 |
+
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
| 1900 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1901 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1902 |
+
${element_c}, ${layout_c},
|
| 1903 |
+
${element_accumulator},
|
| 1904 |
+
${opcode_class},
|
| 1905 |
+
${arch},
|
| 1906 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1907 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1908 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1909 |
+
${epilogue_functor},
|
| 1910 |
+
${swizzling_functor},
|
| 1911 |
+
${stages},
|
| 1912 |
+
${math_operation}
|
| 1913 |
+
>::GemmKernel;
|
| 1914 |
+
|
| 1915 |
+
using ${operation_name}_base =
|
| 1916 |
+
cutlass::gemm::kernel::GemmUniversal<
|
| 1917 |
+
${operation_name}_default::Mma,
|
| 1918 |
+
cutlass::epilogue::threadblock::DefaultEpilogueDirectStore<
|
| 1919 |
+
${operation_name}_default::Epilogue
|
| 1920 |
+
>::Epilogue,
|
| 1921 |
+
${operation_name}_default::ThreadblockSwizzle
|
| 1922 |
+
>;
|
| 1923 |
+
|
| 1924 |
+
// Define named type
|
| 1925 |
+
struct ${operation_name}${operation_suffix} :
|
| 1926 |
+
public ${operation_name}_base { };
|
| 1927 |
+
"""
|
| 1928 |
+
self.gemm_template_kernel_visitor = """
|
| 1929 |
+
|
| 1930 |
+
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
| 1931 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1932 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1933 |
+
${element_c},
|
| 1934 |
+
${align_c},
|
| 1935 |
+
${epilogue_stages} /* epilogue stages */
|
| 1936 |
+
>;
|
| 1937 |
+
|
| 1938 |
+
${callback_decl}
|
| 1939 |
+
|
| 1940 |
+
// Gemm operator ${operation_name}
|
| 1941 |
+
using ${operation_name}_base =
|
| 1942 |
+
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
| 1943 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1944 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1945 |
+
${element_c}, ${layout_c}, ${align_c},
|
| 1946 |
+
${element_accumulator},
|
| 1947 |
+
${element_epilogue},
|
| 1948 |
+
${opcode_class},
|
| 1949 |
+
${arch},
|
| 1950 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1951 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1952 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1953 |
+
${callback_name},
|
| 1954 |
+
${swizzling_functor},
|
| 1955 |
+
${stages},
|
| 1956 |
+
${math_operation},
|
| 1957 |
+
${epilogue_stages} /* epilogue stages */
|
| 1958 |
+
>::GemmKernel;
|
| 1959 |
+
|
| 1960 |
+
// Define named type
|
| 1961 |
+
struct ${operation_name}${operation_suffix} :
|
| 1962 |
+
public ${operation_name}_base { };
|
| 1963 |
+
"""
|
| 1964 |
+
|
| 1965 |
+
def instance_template(self):
|
| 1966 |
+
return """
|
| 1967 |
+
${compile_guard_start}
|
| 1968 |
+
manifest.append(new ${gemm_kind}<
|
| 1969 |
+
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
| 1970 |
+
>("${operation_name}"));
|
| 1971 |
+
${compile_guard_end}
|
| 1972 |
+
"""
|
| 1973 |
+
|
| 1974 |
+
def emit(self, operation):
|
| 1975 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 1976 |
+
warp_count = operation.tile_description.warp_count
|
| 1977 |
+
|
| 1978 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 1979 |
+
|
| 1980 |
+
instance_layout_A, instance_layout_B, instance_layout_C = \
|
| 1981 |
+
(operation.A.layout, operation.B.layout, operation.C.layout)
|
| 1982 |
+
|
| 1983 |
+
if operation.emission_type == EmissionType.Kernel:
|
| 1984 |
+
if self.direct_store:
|
| 1985 |
+
gemm_template = self.gemm_template_direct_store
|
| 1986 |
+
else:
|
| 1987 |
+
gemm_template = self.gemm_template_kernel
|
| 1988 |
+
else:
|
| 1989 |
+
gemm_template = self.gemm_template_device
|
| 1990 |
+
|
| 1991 |
+
values = {
|
| 1992 |
+
"operation_name": operation.procedural_name(),
|
| 1993 |
+
"operation_suffix": self.operation_suffix,
|
| 1994 |
+
"element_a": DataTypeTag[operation.A.element],
|
| 1995 |
+
"layout_a": LayoutTag[instance_layout_A],
|
| 1996 |
+
"element_b": DataTypeTag[operation.B.element],
|
| 1997 |
+
"layout_b": LayoutTag[instance_layout_B],
|
| 1998 |
+
"element_c": DataTypeTag[operation.C.element],
|
| 1999 |
+
"layout_c": LayoutTag[instance_layout_C],
|
| 2000 |
+
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 2001 |
+
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 2002 |
+
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
| 2003 |
+
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
|
| 2004 |
+
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
|
| 2005 |
+
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
|
| 2006 |
+
"warp_shape_m": str(warp_shape[0]),
|
| 2007 |
+
"warp_shape_n": str(warp_shape[1]),
|
| 2008 |
+
"warp_shape_k": str(warp_shape[2]),
|
| 2009 |
+
"instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 2010 |
+
"instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 2011 |
+
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 2012 |
+
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
|
| 2013 |
+
"stages": str(operation.tile_description.stages),
|
| 2014 |
+
"align_a": str(operation.A.alignment),
|
| 2015 |
+
"align_b": str(operation.B.alignment),
|
| 2016 |
+
"transform_a": ComplexTransformTag[operation.A.complex_transform],
|
| 2017 |
+
"transform_b": ComplexTransformTag[operation.B.complex_transform],
|
| 2018 |
+
"math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 2019 |
+
}
|
| 2020 |
+
|
| 2021 |
+
if hasattr(operation.epilogue_functor, "visitor"):
|
| 2022 |
+
self.includes += [
|
| 2023 |
+
"cutlass/epilogue/threadblock/fusion/visitors.hpp",
|
| 2024 |
+
"cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
| 2025 |
+
]
|
| 2026 |
+
callback_name, callback_decl = operation.epilogue_functor.emit(operation)
|
| 2027 |
+
values["callback_name"] = callback_name
|
| 2028 |
+
values["callback_decl"] = callback_decl
|
| 2029 |
+
values["align_c"] = str(operation.C.alignment)
|
| 2030 |
+
values["element_epilogue"] = DataTypeTag[operation.epilogue_functor.element_epilogue]
|
| 2031 |
+
if hasattr(operation.epilogue_functor, "epilogue_stages"):
|
| 2032 |
+
epilogue_stages = operation.epilogue_functor.epilogue_stages
|
| 2033 |
+
else:
|
| 2034 |
+
epilogue_stages = 1
|
| 2035 |
+
values["epilogue_stages"] = str(epilogue_stages)
|
| 2036 |
+
return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
|
| 2037 |
+
else:
|
| 2038 |
+
values["epilogue_functor"] = operation.epilogue_functor.emit()
|
| 2039 |
+
return SubstituteTemplate(gemm_template, values)
|
| 2040 |
+
|
| 2041 |
+
|
| 2042 |
+
class EmitGemmGroupedInstance:
|
| 2043 |
+
"""Responsible for emitting a CUTLASS template definition"""
|
| 2044 |
+
|
| 2045 |
+
def __init__(self, operation_suffix=""):
|
| 2046 |
+
self.operation_suffix = operation_suffix
|
| 2047 |
+
self.includes = [
|
| 2048 |
+
"cutlass/cutlass.h",
|
| 2049 |
+
"cutlass/numeric_types.h",
|
| 2050 |
+
"cutlass/arch/arch.h",
|
| 2051 |
+
"cutlass/arch/mma.h",
|
| 2052 |
+
"cutlass/layout/matrix.h",
|
| 2053 |
+
"cutlass/gemm/kernel/gemm_grouped.h",
|
| 2054 |
+
"cutlass/gemm/kernel/default_gemm_grouped.h",
|
| 2055 |
+
]
|
| 2056 |
+
self.gemm_template_kernel = """
|
| 2057 |
+
// Gemm operator ${operation_name}
|
| 2058 |
+
using ${operation_name}_base =
|
| 2059 |
+
typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
| 2060 |
+
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 2061 |
+
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 2062 |
+
${element_c}, ${layout_c},
|
| 2063 |
+
${element_accumulator},
|
| 2064 |
+
${opcode_class},
|
| 2065 |
+
${arch},
|
| 2066 |
+
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 2067 |
+
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 2068 |
+
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 2069 |
+
${epilogue_functor},
|
| 2070 |
+
${swizzling_functor},
|
| 2071 |
+
${stages},
|
| 2072 |
+
${precompute_mode},
|
| 2073 |
+
${math_operation}
|
| 2074 |
+
>::GemmKernel;
|
| 2075 |
+
|
| 2076 |
+
// Define named type
|
| 2077 |
+
struct ${operation_name}${operation_suffix} :
|
| 2078 |
+
public ${operation_name}_base { };
|
| 2079 |
+
"""
|
| 2080 |
+
self.gemm_template_device = (
|
| 2081 |
+
self.gemm_template_kernel
|
| 2082 |
+
+ """
|
| 2083 |
+
using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>;
|
| 2084 |
+
"""
|
| 2085 |
+
)
|
| 2086 |
+
|
| 2087 |
+
def instance_template(self):
|
| 2088 |
+
return """
|
| 2089 |
+
${compile_guard_start}
|
| 2090 |
+
manifest.append(new ${gemm_kind}<
|
| 2091 |
+
cutlass::gemm::device::GemmGrouped<${operation_name}>
|
| 2092 |
+
>("${operation_name}"));
|
| 2093 |
+
${compile_guard_end}
|
| 2094 |
+
"""
|
| 2095 |
+
|
| 2096 |
+
def emit(self, operation):
|
| 2097 |
+
threadblock_shape = operation.tile_description.threadblock_shape
|
| 2098 |
+
warp_count = operation.tile_description.warp_count
|
| 2099 |
+
|
| 2100 |
+
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 2101 |
+
|
| 2102 |
+
instance_layout_A, instance_layout_B, instance_layout_C = \
|
| 2103 |
+
(operation.A.layout, operation.B.layout, operation.C.layout)
|
| 2104 |
+
|
| 2105 |
+
# Support built-in epilogue functors or user-defined functions
|
| 2106 |
+
epilogue_functor = operation.epilogue_functor.emit()
|
| 2107 |
+
|
| 2108 |
+
values = {
|
| 2109 |
+
"operation_name": operation.procedural_name(),
|
| 2110 |
+
"operation_suffix": self.operation_suffix,
|
| 2111 |
+
"element_a": DataTypeTag[operation.A.element],
|
| 2112 |
+
"layout_a": LayoutTag[instance_layout_A],
|
| 2113 |
+
"element_b": DataTypeTag[operation.B.element],
|
| 2114 |
+
"layout_b": LayoutTag[instance_layout_B],
|
| 2115 |
+
"element_c": DataTypeTag[operation.C.element],
|
| 2116 |
+
"layout_c": LayoutTag[instance_layout_C],
|
| 2117 |
+
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 2118 |
+
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 2119 |
+
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
| 2120 |
+
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
|
| 2121 |
+
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
|
| 2122 |
+
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
|
| 2123 |
+
"warp_shape_m": str(warp_shape[0]),
|
| 2124 |
+
"warp_shape_n": str(warp_shape[1]),
|
| 2125 |
+
"warp_shape_k": str(warp_shape[2]),
|
| 2126 |
+
"instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 2127 |
+
"instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 2128 |
+
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 2129 |
+
"epilogue_functor": epilogue_functor,
|
| 2130 |
+
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
|
| 2131 |
+
"stages": str(operation.tile_description.stages),
|
| 2132 |
+
"align_a": str(operation.A.alignment),
|
| 2133 |
+
"align_b": str(operation.B.alignment),
|
| 2134 |
+
"transform_a": ComplexTransformTag[operation.A.complex_transform],
|
| 2135 |
+
"transform_b": ComplexTransformTag[operation.B.complex_transform],
|
| 2136 |
+
"precompute_mode": SchedulerModeTag[operation.precompute_mode],
|
| 2137 |
+
"math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 2138 |
+
}
|
| 2139 |
+
|
| 2140 |
+
if operation.emission_type == EmissionType.Kernel:
|
| 2141 |
+
gemm_template = self.gemm_template_kernel
|
| 2142 |
+
else:
|
| 2143 |
+
gemm_template = self.gemm_template_device
|
| 2144 |
+
|
| 2145 |
+
return SubstituteTemplate(gemm_template, values)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Common data types and string names/tags for them
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import enum
|
| 38 |
+
|
| 39 |
+
from cutlass_library import (
|
| 40 |
+
ComplexTransform,
|
| 41 |
+
DataType,
|
| 42 |
+
DataTypeSize,
|
| 43 |
+
EpilogueScheduleType,
|
| 44 |
+
KernelScheduleSuffixes,
|
| 45 |
+
KernelScheduleType,
|
| 46 |
+
MathOperation,
|
| 47 |
+
OpcodeClass,
|
| 48 |
+
TileSchedulerType
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
|
| 53 |
+
# as the default 3.5.2 on Ubuntu 16.04.
|
| 54 |
+
#
|
| 55 |
+
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
from enum import auto as enum_auto
|
| 59 |
+
except ImportError:
|
| 60 |
+
__cutlass_library_auto_enum = 0
|
| 61 |
+
|
| 62 |
+
def enum_auto() -> int:
|
| 63 |
+
global __cutlass_library_auto_enum
|
| 64 |
+
i = __cutlass_library_auto_enum
|
| 65 |
+
__cutlass_library_auto_enum += 1
|
| 66 |
+
return i
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DataTypeSizeBytes:
|
| 70 |
+
"""
|
| 71 |
+
Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
|
| 72 |
+
data type key is less than a full byte or a non-integer number of bytes.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
@staticmethod
|
| 76 |
+
def __class_getitem__(datatype):
|
| 77 |
+
"""
|
| 78 |
+
Returns the number of bytes in size the data type is. Raises an exception if the data type
|
| 79 |
+
is either less than a full byte or a non-integer number of bytes in size.
|
| 80 |
+
|
| 81 |
+
:param datatype: data type to query
|
| 82 |
+
|
| 83 |
+
:return: number of bytes the data type occupies
|
| 84 |
+
:rtype: int
|
| 85 |
+
"""
|
| 86 |
+
bits = DataTypeSize[datatype]
|
| 87 |
+
if bits < 8:
|
| 88 |
+
raise Exception(
|
| 89 |
+
f"Data type {datatype} is less than one byte in size."
|
| 90 |
+
)
|
| 91 |
+
elif bits % 8 != 0:
|
| 92 |
+
raise Exception(
|
| 93 |
+
f"Data type datatype is not an integer number of bytes."
|
| 94 |
+
)
|
| 95 |
+
return bits // 8
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SchedulerMode(enum.Enum):
|
| 99 |
+
Device = enum_auto()
|
| 100 |
+
Host = enum_auto()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
SchedulerModeTag = {
|
| 104 |
+
SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
|
| 105 |
+
SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class FunctionalOp(enum.Enum):
|
| 113 |
+
AtomicAdd = enum_auto()
|
| 114 |
+
AtomicMaximum = enum_auto()
|
| 115 |
+
Divides = enum_auto()
|
| 116 |
+
Maximum = enum_auto()
|
| 117 |
+
Minimum = enum_auto()
|
| 118 |
+
Minus = enum_auto()
|
| 119 |
+
Multiplies = enum_auto()
|
| 120 |
+
MultiplyAdd = enum_auto()
|
| 121 |
+
Plus = enum_auto()
|
| 122 |
+
Exp = enum_auto()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
FunctionalOpTag = {
|
| 126 |
+
FunctionalOp.AtomicAdd: "cutlass::atomic_add",
|
| 127 |
+
FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
|
| 128 |
+
FunctionalOp.Divides: "cutlass::divides",
|
| 129 |
+
FunctionalOp.Maximum: "cutlass::maximum",
|
| 130 |
+
FunctionalOp.Minimum: "cutlass::minimum",
|
| 131 |
+
FunctionalOp.Minus: "cutlass::minus",
|
| 132 |
+
FunctionalOp.Multiplies: "cutlass::multiplies",
|
| 133 |
+
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
|
| 134 |
+
FunctionalOp.Plus: "cutlass::plus",
|
| 135 |
+
FunctionalOp.Exp: "cutlass::fast_exp_op",
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ActivationOp(enum.Enum):
|
| 140 |
+
DGelu = enum_auto()
|
| 141 |
+
Gelu = enum_auto()
|
| 142 |
+
GeluTaylor = enum_auto()
|
| 143 |
+
HardSwish = enum_auto()
|
| 144 |
+
Identity = enum_auto()
|
| 145 |
+
LeakyReLU = enum_auto()
|
| 146 |
+
ReLU = enum_auto()
|
| 147 |
+
Sigmoid = enum_auto()
|
| 148 |
+
SiLU = enum_auto()
|
| 149 |
+
Tanh = enum_auto()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
ActivationOpTag = {
|
| 153 |
+
ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
|
| 154 |
+
ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
|
| 155 |
+
ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
|
| 156 |
+
ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
|
| 157 |
+
ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
|
| 158 |
+
ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
|
| 159 |
+
ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
|
| 160 |
+
ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
|
| 161 |
+
ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
|
| 162 |
+
ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def op_tag(op) -> str:
|
| 167 |
+
"""
|
| 168 |
+
Dispatches `op` to the appropriate *Tag dictionary depending on whether
|
| 169 |
+
`op` is an ActivationOp or FunctionalOp. This is useful for cases in which
|
| 170 |
+
either type can be used.
|
| 171 |
+
|
| 172 |
+
:param op: operation to emit a tag for
|
| 173 |
+
:type op: ActivationOp | FunctionalOp
|
| 174 |
+
|
| 175 |
+
:return: tag corresponding to op
|
| 176 |
+
:rtype: str
|
| 177 |
+
"""
|
| 178 |
+
if isinstance(op, ActivationOp):
|
| 179 |
+
return ActivationOpTag[op]
|
| 180 |
+
elif isinstance(op, FunctionalOp):
|
| 181 |
+
return FunctionalOpTag[op]
|
| 182 |
+
else:
|
| 183 |
+
raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class FloatRoundStyle(enum.Enum):
|
| 187 |
+
ToNearest = enum_auto()
|
| 188 |
+
ToNearestSatfinite = enum_auto()
|
| 189 |
+
Indeterminate = enum_auto()
|
| 190 |
+
TowardZero = enum_auto()
|
| 191 |
+
TowardInfinity = enum_auto()
|
| 192 |
+
TowardNegInfinity = enum_auto()
|
| 193 |
+
HalfUlpTruncDntz = enum_auto()
|
| 194 |
+
HalfUlpTruncate = enum_auto()
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
FloatRoundStyleTag = {
|
| 198 |
+
FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
|
| 199 |
+
FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
|
| 200 |
+
FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
|
| 201 |
+
FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
|
| 202 |
+
FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
|
| 203 |
+
FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
|
| 204 |
+
FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
|
| 205 |
+
FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
class MathInstruction:
|
| 210 |
+
"""
|
| 211 |
+
Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel
|
| 212 |
+
"""
|
| 213 |
+
|
| 214 |
+
def __init__(
|
| 215 |
+
self,
|
| 216 |
+
instruction_shape,
|
| 217 |
+
element_a,
|
| 218 |
+
element_b,
|
| 219 |
+
element_accumulator,
|
| 220 |
+
opcode_class=OpcodeClass.Simt,
|
| 221 |
+
math_operation=MathOperation.multiply_add,
|
| 222 |
+
):
|
| 223 |
+
"""
|
| 224 |
+
:param instruction_shape: size of the [M, N, K] dimensions of the instruction
|
| 225 |
+
:type instruction_shape: list or tuple
|
| 226 |
+
:param element_a: data type of operand A
|
| 227 |
+
:param element_b: data type of operand B
|
| 228 |
+
:param element_accumulator: data type used in accumulation
|
| 229 |
+
:param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
|
| 230 |
+
:type opcode_class: cutlass_library.library.OpcodeClass
|
| 231 |
+
:param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
|
| 232 |
+
:type math_operation: MathOperation
|
| 233 |
+
"""
|
| 234 |
+
self.instruction_shape = instruction_shape
|
| 235 |
+
self.element_a = element_a
|
| 236 |
+
self.element_b = element_b
|
| 237 |
+
self.element_accumulator = element_accumulator
|
| 238 |
+
self.opcode_class = opcode_class
|
| 239 |
+
self.math_operation = math_operation
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
|
| 243 |
+
blackwell_threadblock_shape = tile_description.threadblock_shape
|
| 244 |
+
is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
|
| 245 |
+
if cluster_shape[0] > 0:
|
| 246 |
+
blackwell_threadblock_shape = [
|
| 247 |
+
tile_description.threadblock_shape[0] // cluster_shape[0],
|
| 248 |
+
tile_description.threadblock_shape[1] // cluster_shape[1],
|
| 249 |
+
tile_description.threadblock_shape[2] // cluster_shape[2]
|
| 250 |
+
]
|
| 251 |
+
if is_2sm:
|
| 252 |
+
blackwell_threadblock_shape[0] *= 2
|
| 253 |
+
else:
|
| 254 |
+
blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape
|
| 255 |
+
return blackwell_threadblock_shape, is_2sm
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
class TileDescription:
|
| 259 |
+
"""
|
| 260 |
+
Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
|
| 261 |
+
stage count, and math instruction specification
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def __init__(
|
| 265 |
+
self,
|
| 266 |
+
threadblock_shape,
|
| 267 |
+
stages,
|
| 268 |
+
warp_count,
|
| 269 |
+
math_instruction,
|
| 270 |
+
cluster_shape=[1, 1, 1],
|
| 271 |
+
kernel_schedule: KernelScheduleType = None,
|
| 272 |
+
epilogue_schedule: EpilogueScheduleType = None,
|
| 273 |
+
tile_scheduler: TileSchedulerType = None
|
| 274 |
+
):
|
| 275 |
+
"""
|
| 276 |
+
:param threadblock_shape: shape of a threadblock tyle
|
| 277 |
+
:type threadblock_shape: list or tuple
|
| 278 |
+
:param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
|
| 279 |
+
number of stages that can be supported for an operation on a given architecture will be computed at a later time
|
| 280 |
+
:type stages: int or None
|
| 281 |
+
:param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
|
| 282 |
+
:type warp_count: list, tuple, or None
|
| 283 |
+
:param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
|
| 284 |
+
:type math_instruction: MathInstruction
|
| 285 |
+
:param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
|
| 286 |
+
:param kernel_schedule: type of kernel schedule to use (only available for SM90+)
|
| 287 |
+
:type kernel_schedule: cutlass_library.KernelScheduleType
|
| 288 |
+
:param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
|
| 289 |
+
:type epilogue_schedule: cutlass_library.EpilogueScheduleType
|
| 290 |
+
:param tile_scheduler: type of tile scheduler to use (only available for SM90+)
|
| 291 |
+
:type tile_scheduler: cutlass_library.TileSchedulerType
|
| 292 |
+
"""
|
| 293 |
+
if ((kernel_schedule is None and epilogue_schedule is not None) or
|
| 294 |
+
(kernel_schedule is not None and epilogue_schedule is None)):
|
| 295 |
+
raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")
|
| 296 |
+
|
| 297 |
+
self.threadblock_shape = threadblock_shape
|
| 298 |
+
self.cluster_shape = cluster_shape
|
| 299 |
+
self.kernel_schedule = kernel_schedule
|
| 300 |
+
self.epilogue_schedule = epilogue_schedule
|
| 301 |
+
self.tile_scheduler = tile_scheduler
|
| 302 |
+
self.stages = stages
|
| 303 |
+
|
| 304 |
+
self.math_instruction = math_instruction
|
| 305 |
+
self.instruction_shape = math_instruction.instruction_shape
|
| 306 |
+
|
| 307 |
+
# Number of warps along x, y, z directions
|
| 308 |
+
self.warp_count = warp_count
|
| 309 |
+
|
| 310 |
+
self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)
|
| 311 |
+
|
| 312 |
+
def clone_and_update(self, td: dict):
|
| 313 |
+
attrs = {
|
| 314 |
+
"cluster_shape": None,
|
| 315 |
+
"threadblock_shape": None,
|
| 316 |
+
"warp_count": None,
|
| 317 |
+
"stages": None,
|
| 318 |
+
"instruction_shape": None,
|
| 319 |
+
"kernel_schedule": None,
|
| 320 |
+
"epilogue_schedule": None,
|
| 321 |
+
"tile_scheduler": None
|
| 322 |
+
}
|
| 323 |
+
for key in attrs.keys():
|
| 324 |
+
if key in td.keys():
|
| 325 |
+
attrs[key] = td[key]
|
| 326 |
+
else:
|
| 327 |
+
attrs[key] = getattr(self, key)
|
| 328 |
+
|
| 329 |
+
attrs["math_instruction"] = MathInstruction(
|
| 330 |
+
attrs["instruction_shape"],
|
| 331 |
+
self.math_instruction.element_a,
|
| 332 |
+
self.math_instruction.element_b,
|
| 333 |
+
self.math_instruction.element_accumulator,
|
| 334 |
+
self.math_instruction.opcode_class,
|
| 335 |
+
self.math_instruction.math_operation
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# Remove the instruction shape
|
| 339 |
+
del attrs["instruction_shape"]
|
| 340 |
+
|
| 341 |
+
return TileDescription(**attrs)
|
| 342 |
+
|
| 343 |
+
@property
|
| 344 |
+
def num_threads(self):
|
| 345 |
+
"""
|
| 346 |
+
Returns the number of threads in the threadblock
|
| 347 |
+
|
| 348 |
+
:return: number of threads in the threadblock
|
| 349 |
+
:rtype: int or None (if warp count is None)
|
| 350 |
+
"""
|
| 351 |
+
if self.warp_count is not None:
|
| 352 |
+
threads = 32
|
| 353 |
+
for cnt in self.warp_count:
|
| 354 |
+
threads *= cnt
|
| 355 |
+
return threads
|
| 356 |
+
return None
|
| 357 |
+
|
| 358 |
+
def procedural_name(self):
|
| 359 |
+
"""
|
| 360 |
+
Returns a name identifying the tile description
|
| 361 |
+
|
| 362 |
+
:return: name identifying the tile description
|
| 363 |
+
:rtype: int
|
| 364 |
+
"""
|
| 365 |
+
emit_stages = 0 if self.stages is None else self.stages
|
| 366 |
+
name = "%dx%dx%d_%dx%d_%dx%d" % (
|
| 367 |
+
self.cluster_shape[0],
|
| 368 |
+
self.cluster_shape[1],
|
| 369 |
+
self.cluster_shape[2],
|
| 370 |
+
self.threadblock_shape[0],
|
| 371 |
+
self.threadblock_shape[1],
|
| 372 |
+
self.threadblock_shape[2],
|
| 373 |
+
emit_stages
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
return name
|
| 377 |
+
|
| 378 |
+
def procedural_name_2x(self):
|
| 379 |
+
"""
|
| 380 |
+
Returns a name identifying the tile description
|
| 381 |
+
|
| 382 |
+
:return: name identifying the tile description
|
| 383 |
+
:rtype: int
|
| 384 |
+
"""
|
| 385 |
+
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
|
| 386 |
+
|
| 387 |
+
def __str__(self):
|
| 388 |
+
"""
|
| 389 |
+
Returns a string with containing each of the tile description's values
|
| 390 |
+
|
| 391 |
+
:return: contents of tile description
|
| 392 |
+
:rtype: str
|
| 393 |
+
"""
|
| 394 |
+
if self.kernel_schedule is not None:
|
| 395 |
+
kschedule = self.kernel_schedule
|
| 396 |
+
else:
|
| 397 |
+
kschedule = KernelScheduleType.ScheduleAuto
|
| 398 |
+
|
| 399 |
+
if self.epilogue_schedule is not None:
|
| 400 |
+
eschedule = self.epilogue_schedule
|
| 401 |
+
else:
|
| 402 |
+
eschedule = EpilogueScheduleType.ScheduleAuto
|
| 403 |
+
|
| 404 |
+
if self.tile_scheduler is not None:
|
| 405 |
+
tschedule = self.tile_scheduler.name
|
| 406 |
+
else:
|
| 407 |
+
tschedule = "None"
|
| 408 |
+
return f"""
|
| 409 |
+
{{
|
| 410 |
+
ClusterShape: {self.cluster_shape}
|
| 411 |
+
ThreadblockShape: {self.threadblock_shape}
|
| 412 |
+
WarpCount: {self.warp_count}
|
| 413 |
+
Stages: {self.stages if self.stages is not None else 'Auto'}
|
| 414 |
+
InstructionShape: {self.math_instruction.instruction_shape}
|
| 415 |
+
Kernel schedule: {kschedule.name}
|
| 416 |
+
Epilogue schedule: {kschedule.name}
|
| 417 |
+
TileScheduler: {tschedule}
|
| 418 |
+
}}"""
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class TensorDescription:
|
| 422 |
+
def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
|
| 423 |
+
self.element = element
|
| 424 |
+
self.layout = layout
|
| 425 |
+
if element != DataType.void:
|
| 426 |
+
self.alignment = min(128 // DataTypeSize[self.element], alignment)
|
| 427 |
+
else:
|
| 428 |
+
self.alignment = alignment
|
| 429 |
+
self.complex_transform = complex_transform
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def CalculateSmemUsagePerStage(operation):
|
| 433 |
+
"""
|
| 434 |
+
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
|
| 435 |
+
|
| 436 |
+
:param op: operation for which the maximum stages should be computed. If stages are
|
| 437 |
+
set via the `op.tile_description.stages` parameter, this setting is ignored
|
| 438 |
+
in the present calculation
|
| 439 |
+
:type op: cutlass_cppgen.backend.Operation
|
| 440 |
+
|
| 441 |
+
:return: number of bytes of shared memory consumed by a single stage
|
| 442 |
+
:rtype: int
|
| 443 |
+
"""
|
| 444 |
+
m, n, k = operation.tile_description.threadblock_shape
|
| 445 |
+
|
| 446 |
+
if operation.operation_kind == OperationKind.Gemm:
|
| 447 |
+
stage_barrier_bytes = 32
|
| 448 |
+
return (
|
| 449 |
+
(DataTypeSize[operation.A.element] * m * k // 8)
|
| 450 |
+
+ (DataTypeSize[operation.B.element] * k * n // 8)
|
| 451 |
+
+ stage_barrier_bytes
|
| 452 |
+
)
|
| 453 |
+
else:
|
| 454 |
+
raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def CalculateSmemUsage(operation):
|
| 458 |
+
"""
|
| 459 |
+
Returns the amount of shared memory in bytes consumed by a kernel.
|
| 460 |
+
|
| 461 |
+
:param op: operation for which the maximum stages should be computed. If stages are
|
| 462 |
+
set via the `op.tile_description.stages` parameter, this setting is ignored
|
| 463 |
+
in the present calculation
|
| 464 |
+
:type op: cutlass_cppgen.backend.Operation
|
| 465 |
+
|
| 466 |
+
:return: int
|
| 467 |
+
"""
|
| 468 |
+
return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
class ApiVersion(enum.Enum):
|
| 472 |
+
"""
|
| 473 |
+
Differentiate between CUTLASS 2.x and 3.x API versions
|
| 474 |
+
"""
|
| 475 |
+
|
| 476 |
+
v2x = enum_auto()
|
| 477 |
+
v3x = enum_auto()
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def api_version(arch, opclass, dtype):
|
| 481 |
+
"""
|
| 482 |
+
Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
|
| 483 |
+
or 3.x for code emission.
|
| 484 |
+
|
| 485 |
+
:param arch: compute capability of device on which to run
|
| 486 |
+
:type arch: int
|
| 487 |
+
:param opclass: class of the operation being performed
|
| 488 |
+
:type opclass: cutlass_library.OpcodeClass
|
| 489 |
+
:param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
|
| 490 |
+
:type dtype: cutlass_library.DataType
|
| 491 |
+
|
| 492 |
+
:return: API version to be used in code emission
|
| 493 |
+
:rtype: ApiVersion
|
| 494 |
+
"""
|
| 495 |
+
if (arch in [90, 100, 101, 103] and
|
| 496 |
+
opclass == OpcodeClass.TensorOp and
|
| 497 |
+
(dtype != DataType.f64)):
|
| 498 |
+
return ApiVersion.v3x
|
| 499 |
+
else:
|
| 500 |
+
return ApiVersion.v2x
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class EmissionType(enum.Enum):
|
| 504 |
+
"""
|
| 505 |
+
Tags for whether to emit a kernel- or device-level operation
|
| 506 |
+
"""
|
| 507 |
+
|
| 508 |
+
Kernel = enum_auto()
|
| 509 |
+
Device = enum_auto()
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
import cutlass_cppgen
|
| 36 |
+
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
| 37 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 38 |
+
|
| 39 |
+
if cutlass_cppgen.use_rmm:
|
| 40 |
+
import rmm
|
| 41 |
+
else:
|
| 42 |
+
cudart = lazy_import("cuda.cudart")
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PoolMemoryManager:
|
| 46 |
+
def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
|
| 47 |
+
self.pool = rmm.mr.PoolMemoryResource(
|
| 48 |
+
rmm.mr.CudaMemoryResource(),
|
| 49 |
+
initial_pool_size=init_pool_size,
|
| 50 |
+
maximum_pool_size=max_pool_size
|
| 51 |
+
)
|
| 52 |
+
self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
|
| 53 |
+
rmm.mr.set_current_device_resource(self.mr)
|
| 54 |
+
|
| 55 |
+
def pool_size(self):
|
| 56 |
+
return self.pool.pool_size()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class DevicePtrWrapper:
|
| 60 |
+
"""
|
| 61 |
+
Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
|
| 62 |
+
(at least in terms of the interface used by the CUTLASS Python interface)
|
| 63 |
+
"""
|
| 64 |
+
def __init__(self, dev_ptr):
|
| 65 |
+
self.dev_ptr = dev_ptr
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def ptr(self):
|
| 69 |
+
return self.dev_ptr
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _todevice(host_data):
|
| 73 |
+
"""
|
| 74 |
+
Helper for transferring host data to device memory
|
| 75 |
+
"""
|
| 76 |
+
if cutlass_cppgen.use_rmm:
|
| 77 |
+
return rmm.DeviceBuffer.to_device(host_data.tobytes())
|
| 78 |
+
else:
|
| 79 |
+
nbytes = len(host_data.tobytes())
|
| 80 |
+
dev_ptr_wrapper = device_mem_alloc(nbytes)
|
| 81 |
+
err, = cudart.cudaMemcpy(
|
| 82 |
+
dev_ptr_wrapper.ptr,
|
| 83 |
+
host_data.__array_interface__['data'][0],
|
| 84 |
+
nbytes,
|
| 85 |
+
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
|
| 86 |
+
)
|
| 87 |
+
if err != cudart.cudaError_t.cudaSuccess:
|
| 88 |
+
raise Exception(f"cudaMemcpy failed with error {err}")
|
| 89 |
+
return dev_ptr_wrapper
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def todevice(host_data, dtype=np.float32):
|
| 93 |
+
"""
|
| 94 |
+
Pass the host_data to device memory
|
| 95 |
+
"""
|
| 96 |
+
if isinstance(host_data, list):
|
| 97 |
+
return _todevice(np.array(host_data, dtype=dtype))
|
| 98 |
+
elif is_numpy_tensor(host_data):
|
| 99 |
+
return _todevice(host_data)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def device_mem_alloc(size):
|
| 103 |
+
if cutlass_cppgen.use_rmm:
|
| 104 |
+
return rmm.DeviceBuffer(size=size)
|
| 105 |
+
else:
|
| 106 |
+
err, ptr = cudart.cudaMalloc(size)
|
| 107 |
+
if err != cudart.cudaError_t.cudaSuccess:
|
| 108 |
+
raise Exception(f"cudaMalloc failed with error {err}")
|
| 109 |
+
return DevicePtrWrapper(ptr)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def align_size(size, alignment=256):
|
| 113 |
+
return ((size + alignment - 1) // alignment) * alignment
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
|
| 117 |
+
if cutlass_cppgen.use_rmm:
|
| 118 |
+
memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
|
| 119 |
+
return memory_pool
|
| 120 |
+
else:
|
| 121 |
+
return None
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import ctypes
|
| 34 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 35 |
+
cuda = lazy_import("cuda.cuda")
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 38 |
+
|
| 39 |
+
_supports_cluster_launch = None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def supports_cluster_launch():
|
| 43 |
+
from cuda import __version__
|
| 44 |
+
_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
|
| 45 |
+
global _supports_cluster_launch
|
| 46 |
+
if _supports_cluster_launch is None:
|
| 47 |
+
major, minor = _version_splits[0], _version_splits[1]
|
| 48 |
+
_supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
|
| 49 |
+
return _supports_cluster_launch
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class LaunchConfiguration:
|
| 53 |
+
def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0):
|
| 54 |
+
self.grid = grid
|
| 55 |
+
self.block = block
|
| 56 |
+
self.shared_memory_capacity = smem
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class ExecutableOperation:
|
| 60 |
+
def __init__(self, operation):
|
| 61 |
+
self.operation = operation
|
| 62 |
+
self.module = None
|
| 63 |
+
self.kernel = None
|
| 64 |
+
|
| 65 |
+
def name(self):
|
| 66 |
+
return self.operation.procedural_name()
|
| 67 |
+
|
| 68 |
+
def emit(self):
|
| 69 |
+
return ""
|
| 70 |
+
|
| 71 |
+
def can_implement(self, configuration, arguments):
|
| 72 |
+
raise NotImplementedError()
|
| 73 |
+
|
| 74 |
+
def get_host_workspace_size(self, arguments):
|
| 75 |
+
raise NotImplementedError()
|
| 76 |
+
|
| 77 |
+
def get_device_workspace_size(self, arguments):
|
| 78 |
+
raise NotImplementedError()
|
| 79 |
+
|
| 80 |
+
def plan(self, arguments):
|
| 81 |
+
raise NotImplementedError()
|
| 82 |
+
|
| 83 |
+
def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
|
| 84 |
+
raise NotImplementedError()
|
| 85 |
+
|
| 86 |
+
def run_with_clusters(self, launch_config, kernel_params, stream=None):
|
| 87 |
+
if not stream:
|
| 88 |
+
stream = cuda.CUstream(0)
|
| 89 |
+
if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
|
| 90 |
+
attr = cuda.CUlaunchAttribute()
|
| 91 |
+
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
|
| 92 |
+
attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
| 93 |
+
attrs = [attr]
|
| 94 |
+
|
| 95 |
+
# Allow for non-portable cluster sizes
|
| 96 |
+
err, = cuda.cuFuncSetAttribute(
|
| 97 |
+
self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
|
| 98 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 99 |
+
return err
|
| 100 |
+
else:
|
| 101 |
+
attrs = []
|
| 102 |
+
|
| 103 |
+
config = cuda.CUlaunchConfig()
|
| 104 |
+
config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
|
| 105 |
+
config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
|
| 106 |
+
config.blockDimZ = launch_config.block[2]
|
| 107 |
+
config.sharedMemBytes = launch_config.shared_memory_capacity
|
| 108 |
+
config.hStream = stream
|
| 109 |
+
config.attrs = attrs
|
| 110 |
+
config.numAttrs = len(attrs)
|
| 111 |
+
|
| 112 |
+
err, = cuda.cuLaunchKernelEx(
|
| 113 |
+
config, f=self.kernel, kernelParams=kernel_params, extra=0)
|
| 114 |
+
return err
|
| 115 |
+
|
| 116 |
+
def run_without_clusters(self, launch_config, kernel_params, stream=None):
|
| 117 |
+
if not stream:
|
| 118 |
+
stream = cuda.CUstream(0)
|
| 119 |
+
err, = cuda.cuLaunchKernel(
|
| 120 |
+
self.kernel,
|
| 121 |
+
launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
|
| 122 |
+
launch_config.block[0], launch_config.block[1], launch_config.block[2],
|
| 123 |
+
launch_config.shared_memory_capacity,
|
| 124 |
+
stream,
|
| 125 |
+
kernel_params,
|
| 126 |
+
0)
|
| 127 |
+
|
| 128 |
+
return err
|
| 129 |
+
|
| 130 |
+
def run(self, host_workspace, device_workspace, launch_config, stream=None):
|
| 131 |
+
if not stream:
|
| 132 |
+
stream = cuda.CUstream(0)
|
| 133 |
+
cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
|
| 134 |
+
packed = (ctypes.c_void_p * 1)()
|
| 135 |
+
packed[0] = ctypes.addressof(cArg)
|
| 136 |
+
|
| 137 |
+
if supports_cluster_launch():
|
| 138 |
+
return self.run_with_clusters(launch_config, packed, stream)
|
| 139 |
+
else:
|
| 140 |
+
return self.run_without_clusters(launch_config, packed, stream)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
from __future__ import annotations
|
| 33 |
+
|
| 34 |
+
import ctypes
|
| 35 |
+
from typing import Union
|
| 36 |
+
|
| 37 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 38 |
+
cuda = lazy_import("cuda.cuda")
|
| 39 |
+
cudart = lazy_import("cuda.cudart")
|
| 40 |
+
import numpy as np
|
| 41 |
+
|
| 42 |
+
from cutlass_library import (
|
| 43 |
+
DataTypeNames,
|
| 44 |
+
DataTypeSize,
|
| 45 |
+
DataTypeTag,
|
| 46 |
+
LayoutType,
|
| 47 |
+
SubstituteTemplate
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
import cutlass_cppgen
|
| 51 |
+
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
|
| 52 |
+
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
|
| 53 |
+
from cutlass_cppgen.backend.library import TensorDescription
|
| 54 |
+
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
|
| 55 |
+
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
| 56 |
+
from cutlass_cppgen.shape import MatrixCoord
|
| 57 |
+
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ReductionOperation:
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ReductionArguments:
|
| 65 |
+
"""
|
| 66 |
+
Arguments of reduction
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def __init__(
|
| 70 |
+
self,
|
| 71 |
+
operation: ReductionOperation,
|
| 72 |
+
problem_size: "list[int]",
|
| 73 |
+
partitions: int,
|
| 74 |
+
workspace: cuda.CUdeviceptr,
|
| 75 |
+
destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
|
| 76 |
+
source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
|
| 77 |
+
**kwargs,
|
| 78 |
+
) -> None:
|
| 79 |
+
# tensor_C can be interpreted as the bias with bias=True in keyword args
|
| 80 |
+
if "bias" in kwargs.keys():
|
| 81 |
+
self.bias = kwargs["bias"]
|
| 82 |
+
else:
|
| 83 |
+
# by default, tensor_C is not bias
|
| 84 |
+
self.bias = False
|
| 85 |
+
if "stream" in kwargs.keys():
|
| 86 |
+
self.stream = kwargs["stream"]
|
| 87 |
+
else:
|
| 88 |
+
self.stream = cuda.CUstream(0)
|
| 89 |
+
|
| 90 |
+
self.operation = operation
|
| 91 |
+
self.ptr_workspace = workspace
|
| 92 |
+
|
| 93 |
+
# number of split-k partitions
|
| 94 |
+
self.partitions = partitions
|
| 95 |
+
|
| 96 |
+
if is_numpy_tensor(destination):
|
| 97 |
+
self.host_D = destination
|
| 98 |
+
self.destination_buffer = NumpyFrontend.argument(destination, True)
|
| 99 |
+
self.source_buffer = NumpyFrontend.argument(source, False)
|
| 100 |
+
self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
|
| 101 |
+
self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
|
| 102 |
+
elif is_torch_tensor(destination):
|
| 103 |
+
self.ptr_destination = TorchFrontend.argument(destination)
|
| 104 |
+
self.ptr_source = TorchFrontend.argument(source)
|
| 105 |
+
elif isinstance(destination, cuda.CUdeviceptr):
|
| 106 |
+
self.ptr_destination = destination
|
| 107 |
+
self.ptr_source = source
|
| 108 |
+
else:
|
| 109 |
+
raise TypeError("unknown Type")
|
| 110 |
+
|
| 111 |
+
self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])
|
| 112 |
+
|
| 113 |
+
self.partition_stride = (
|
| 114 |
+
problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if "output_op" in kwargs.keys():
|
| 118 |
+
self.output_op = kwargs["output_op"]
|
| 119 |
+
else:
|
| 120 |
+
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
| 121 |
+
|
| 122 |
+
self.get_arguments()
|
| 123 |
+
|
| 124 |
+
@staticmethod
|
| 125 |
+
def get_tensor_ref(
|
| 126 |
+
extent: "tuple[int]",
|
| 127 |
+
device_ptr: cuda.CUdeviceptr,
|
| 128 |
+
layout: LayoutType,
|
| 129 |
+
):
|
| 130 |
+
if layout == LayoutType.RowMajor:
|
| 131 |
+
return TensorRef2D_(int(device_ptr), extent[1])
|
| 132 |
+
else:
|
| 133 |
+
raise ValueError(f"Unknown layout type {layout}")
|
| 134 |
+
|
| 135 |
+
def get_arguments(self):
|
| 136 |
+
ref_workspace = ReductionArguments.get_tensor_ref(
|
| 137 |
+
extent=[
|
| 138 |
+
self.problem_size.row,
|
| 139 |
+
self.problem_size.column,
|
| 140 |
+
],
|
| 141 |
+
device_ptr=self.ptr_workspace,
|
| 142 |
+
layout=LayoutType.RowMajor,
|
| 143 |
+
)
|
| 144 |
+
if self.bias:
|
| 145 |
+
ref_source = ReductionArguments.get_tensor_ref(
|
| 146 |
+
extent=[0, 0],
|
| 147 |
+
device_ptr=self.ptr_source,
|
| 148 |
+
layout=LayoutType.RowMajor,
|
| 149 |
+
)
|
| 150 |
+
else:
|
| 151 |
+
ref_source = ReductionArguments.get_tensor_ref(
|
| 152 |
+
extent=[
|
| 153 |
+
self.problem_size.row,
|
| 154 |
+
self.problem_size.column,
|
| 155 |
+
],
|
| 156 |
+
device_ptr=self.ptr_source,
|
| 157 |
+
layout=LayoutType.RowMajor,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
ref_destination = ReductionArguments.get_tensor_ref(
|
| 161 |
+
extent=[
|
| 162 |
+
self.problem_size.row,
|
| 163 |
+
self.problem_size.column,
|
| 164 |
+
],
|
| 165 |
+
device_ptr=self.ptr_destination,
|
| 166 |
+
layout=LayoutType.RowMajor,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
self.c_arguments = self.operation.argument_type(
|
| 170 |
+
self.problem_size,
|
| 171 |
+
self.partitions,
|
| 172 |
+
self.partition_stride,
|
| 173 |
+
ref_workspace,
|
| 174 |
+
ref_destination,
|
| 175 |
+
ref_source,
|
| 176 |
+
self.output_op,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
|
| 180 |
+
self.host_workspace = bytearray(params_.contents)
|
| 181 |
+
|
| 182 |
+
def sync(self):
|
| 183 |
+
(err,) = cudart.cudaDeviceSynchronize()
|
| 184 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 185 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 186 |
+
|
| 187 |
+
if hasattr(self, "host_D"):
|
| 188 |
+
(err,) = cuda.cuMemcpyDtoH(
|
| 189 |
+
self.host_D,
|
| 190 |
+
self.ptr_destination,
|
| 191 |
+
self.host_D.size * self.host_D.itemsize,
|
| 192 |
+
)
|
| 193 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 194 |
+
raise RuntimeError("CUDA Error %s" % str(err))
|
| 195 |
+
|
| 196 |
+
self.free()
|
| 197 |
+
|
| 198 |
+
def free(self):
|
| 199 |
+
"""
|
| 200 |
+
Frees allocated device-side memory
|
| 201 |
+
"""
|
| 202 |
+
# Free any device memory allocated manually
|
| 203 |
+
if not cutlass_cppgen.use_rmm:
|
| 204 |
+
for attr in ["destination_buffer", "source_buffer"]:
|
| 205 |
+
if hasattr(self, attr):
|
| 206 |
+
buf = getattr(self, attr)
|
| 207 |
+
if isinstance(buf, DevicePtrWrapper):
|
| 208 |
+
err, = cudart.cudaFree(buf.ptr)
|
| 209 |
+
if err != cudart.cudaError_t.cudaSuccess:
|
| 210 |
+
raise RuntimeError(f"cudaFree failed with error {err}")
|
| 211 |
+
del buf
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class ReductionRT(ExecutableOperation):
|
| 215 |
+
"""
|
| 216 |
+
ReductionRT manages the CUTLASS runtime components for reduction
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
KernelTemplate = r"""
|
| 220 |
+
extern "C"
|
| 221 |
+
__global__ void
|
| 222 |
+
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
| 223 |
+
|
| 224 |
+
// Dynamic shared memory base pointer
|
| 225 |
+
extern __shared__ int SharedStorageBase[];
|
| 226 |
+
|
| 227 |
+
// Declare pointer to dynamic shared memory.
|
| 228 |
+
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
| 229 |
+
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
| 230 |
+
|
| 231 |
+
${operation_name}${operation_suffix} op;
|
| 232 |
+
|
| 233 |
+
op(params, *shared_storage);
|
| 234 |
+
}
|
| 235 |
+
"""
|
| 236 |
+
HostTemplate = r"""
|
| 237 |
+
extern "C" {
|
| 238 |
+
// Get the size of params in bytes
|
| 239 |
+
int ${operation_name}_get_param_size(){
|
| 240 |
+
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
// Get the size of dynamic shared memory in bytes
|
| 244 |
+
int ${operation_name}_shared_memory_size() {
|
| 245 |
+
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
// Get the params as byte array
|
| 249 |
+
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
|
| 250 |
+
char *bytes = ((char*)(params));
|
| 251 |
+
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
|
| 252 |
+
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
|
| 253 |
+
output[i] = bytes[i];
|
| 254 |
+
|
| 255 |
+
return output;
|
| 256 |
+
}
|
| 257 |
+
}
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(self, operation: ReductionOperation):
|
| 261 |
+
super().__init__(operation)
|
| 262 |
+
|
| 263 |
+
self.operation: ReductionOperation = operation
|
| 264 |
+
self.emitter = EmitReductionInstance("_type")
|
| 265 |
+
|
| 266 |
+
self.elements_per_access = self.operation.count
|
| 267 |
+
(
|
| 268 |
+
self.argument_type,
|
| 269 |
+
self.epilogue_type,
|
| 270 |
+
) = get_reduction_params(operation.epilogue_functor)
|
| 271 |
+
self.argtype = [ctypes.POINTER(self.argument_type)]
|
| 272 |
+
|
| 273 |
+
def emit(self):
|
| 274 |
+
return self.emitter.emit(self.operation)
|
| 275 |
+
|
| 276 |
+
def plan(self, arguments: ReductionArguments):
|
| 277 |
+
block_shape = [
|
| 278 |
+
self.operation.shape.column // self.elements_per_access,
|
| 279 |
+
self.operation.shape.row,
|
| 280 |
+
1,
|
| 281 |
+
]
|
| 282 |
+
grid_shape = [
|
| 283 |
+
(arguments.problem_size.row + self.operation.shape.row - 1)
|
| 284 |
+
// self.operation.shape.row,
|
| 285 |
+
(arguments.problem_size.column + self.operation.shape.column - 1)
|
| 286 |
+
// self.operation.shape.column,
|
| 287 |
+
1,
|
| 288 |
+
]
|
| 289 |
+
return LaunchConfiguration(
|
| 290 |
+
grid_shape,
|
| 291 |
+
block_shape,
|
| 292 |
+
self.shared_memory_capacity,
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def initialize(self):
|
| 296 |
+
(err,) = cuda.cuFuncSetAttribute(
|
| 297 |
+
self.kernel,
|
| 298 |
+
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
| 299 |
+
value=self.shared_memory_capacity,
|
| 300 |
+
)
|
| 301 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 302 |
+
raise RuntimeError(f"CUDA Error: {err}")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class ReductionOperation:
|
| 306 |
+
"""
|
| 307 |
+
CUTLASS reduction Operation
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def __init__(
|
| 311 |
+
self,
|
| 312 |
+
shape: MatrixCoord,
|
| 313 |
+
C: TensorDescription,
|
| 314 |
+
element_accumulator,
|
| 315 |
+
element_workspace=None,
|
| 316 |
+
element_compute=None,
|
| 317 |
+
epilogue_functor=None,
|
| 318 |
+
count: int = 1,
|
| 319 |
+
partitions_per_stage: int = 4,
|
| 320 |
+
) -> None:
|
| 321 |
+
self.shape = shape
|
| 322 |
+
self.epilogue_functor = epilogue_functor
|
| 323 |
+
self.element_accumulator = element_accumulator
|
| 324 |
+
|
| 325 |
+
if element_workspace is None:
|
| 326 |
+
self.element_workspace = element_accumulator
|
| 327 |
+
else:
|
| 328 |
+
self.element_workspace = element_workspace
|
| 329 |
+
|
| 330 |
+
if element_compute is None:
|
| 331 |
+
self.element_compute = element_accumulator
|
| 332 |
+
else:
|
| 333 |
+
self.element_compute = element_compute
|
| 334 |
+
|
| 335 |
+
self.element_output = C.element
|
| 336 |
+
self.C: TensorDescription = C
|
| 337 |
+
|
| 338 |
+
# Reduce op processing size
|
| 339 |
+
self.count: int = count
|
| 340 |
+
|
| 341 |
+
# Number of partitions to reduce per stage
|
| 342 |
+
self.partitions_per_stage: int = partitions_per_stage
|
| 343 |
+
|
| 344 |
+
self.rt_module: ReductionRT = ReductionRT(self)
|
| 345 |
+
self.argument_type = self.rt_module.argument_type
|
| 346 |
+
self.epilogue_type = self.rt_module.epilogue_type
|
| 347 |
+
|
| 348 |
+
def extended_name(self):
|
| 349 |
+
extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
|
| 350 |
+
|
| 351 |
+
return SubstituteTemplate(
|
| 352 |
+
extend_name,
|
| 353 |
+
{
|
| 354 |
+
"element_workspace": DataTypeNames[self.element_workspace],
|
| 355 |
+
"element_accumulator": DataTypeNames[self.element_accumulator],
|
| 356 |
+
"element_compute": DataTypeNames[self.element_compute],
|
| 357 |
+
"element_output": DataTypeNames[self.element_output],
|
| 358 |
+
},
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def configuration_name(self):
|
| 362 |
+
"""The full procedural name indicates architecture, extended name, tile size"""
|
| 363 |
+
|
| 364 |
+
configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
|
| 365 |
+
|
| 366 |
+
threadblock = "%dx%d" % (
|
| 367 |
+
self.shape.row,
|
| 368 |
+
self.shape.column,
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
return SubstituteTemplate(
|
| 372 |
+
configuration_name,
|
| 373 |
+
{
|
| 374 |
+
"extended_name": self.extended_name(),
|
| 375 |
+
"threadblock": threadblock,
|
| 376 |
+
},
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
def procedural_name(self):
|
| 380 |
+
"""The full procedural name indicates architeture, extended name, tile size"""
|
| 381 |
+
return self.configuration_name()
|
| 382 |
+
|
| 383 |
+
def run(self, arguments: ReductionArguments) -> cuda.CUresult:
|
| 384 |
+
"""
|
| 385 |
+
Configure and launch the cuda kernel with input arguments
|
| 386 |
+
"""
|
| 387 |
+
launch_config = self.rt_module.plan(arguments)
|
| 388 |
+
|
| 389 |
+
host_workspace = arguments.host_workspace
|
| 390 |
+
device_workspace = None
|
| 391 |
+
|
| 392 |
+
err = self.rt_module.run(
|
| 393 |
+
host_workspace,
|
| 394 |
+
device_workspace,
|
| 395 |
+
launch_config,
|
| 396 |
+
arguments.stream
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 400 |
+
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 401 |
+
|
| 402 |
+
return err
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
class EmitReductionInstance:
|
| 406 |
+
def __init__(self, operation_suffix="") -> None:
|
| 407 |
+
self.operation_suffix = operation_suffix
|
| 408 |
+
self.includes = [
|
| 409 |
+
"cutlass/cutlass.h",
|
| 410 |
+
"cutlass/numeric_types.h",
|
| 411 |
+
"cutlass/arch/arch.h",
|
| 412 |
+
"cutlass/arch/mma.h",
|
| 413 |
+
"cutlass/layout/matrix.h",
|
| 414 |
+
"cutlass/gemm/device/gemm.h",
|
| 415 |
+
"cutlass/gemm/device/gemm_universal_adapter.h",
|
| 416 |
+
"cutlass/gemm/kernel/default_gemm_universal.h",
|
| 417 |
+
"cutlass/reduction/kernel/reduce_split_k.h",
|
| 418 |
+
"cutlass/reduction/thread/reduction_operators.h",
|
| 419 |
+
]
|
| 420 |
+
self.template = """
|
| 421 |
+
// Reduction kernel instance
|
| 422 |
+
using ${operation_name}_base =
|
| 423 |
+
typename cutlass::reduction::kernel::ReduceSplitK<
|
| 424 |
+
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
|
| 425 |
+
${epilogue_functor},
|
| 426 |
+
cutlass::reduction::thread::ReduceAdd<
|
| 427 |
+
${element_accumulator},
|
| 428 |
+
${element_output},
|
| 429 |
+
${count}>,
|
| 430 |
+
${partition_per_stage}>;
|
| 431 |
+
|
| 432 |
+
struct ${operation_name}${operation_suffix}:
|
| 433 |
+
public ${operation_name}_base { };
|
| 434 |
+
"""
|
| 435 |
+
|
| 436 |
+
def emit(self, operation: ReductionOperation):
|
| 437 |
+
vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
|
| 438 |
+
epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
|
| 439 |
+
|
| 440 |
+
values = {
|
| 441 |
+
"operation_name": operation.configuration_name(),
|
| 442 |
+
"operation_suffix": self.operation_suffix,
|
| 443 |
+
"shape_row": str(operation.shape.row),
|
| 444 |
+
"shape_column": str(operation.shape.column),
|
| 445 |
+
"epilogue_functor": operation.epilogue_functor.emit(),
|
| 446 |
+
"element_output": DataTypeTag[operation.element_output],
|
| 447 |
+
"epilogue_vector_length": str(epilogue_vector_length),
|
| 448 |
+
"element_accumulator": DataTypeTag[operation.element_accumulator],
|
| 449 |
+
"element_compute": DataTypeTag[operation.element_compute],
|
| 450 |
+
"element_workspace": DataTypeTag[operation.element_workspace],
|
| 451 |
+
"count": str(operation.count),
|
| 452 |
+
"partition_per_stage": str(operation.partitions_per_stage),
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
return SubstituteTemplate(self.template, values)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"
|
| 34 |
+
|
| 35 |
+
Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utility functions for interacting with the device
|
| 35 |
+
"""
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 39 |
+
cuda = lazy_import("cuda.cuda")
|
| 40 |
+
cudart = lazy_import("cuda.cudart")
|
| 41 |
+
|
| 42 |
+
import cutlass_cppgen
|
| 43 |
+
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def check_cuda_errors(result: list):
|
| 47 |
+
"""
|
| 48 |
+
Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise,
|
| 49 |
+
returns the result contained in the remaining fields of `result`.
|
| 50 |
+
|
| 51 |
+
:param result: the results of the `cudart` method, consisting of an error code and any method results
|
| 52 |
+
:type result: list
|
| 53 |
+
|
| 54 |
+
:return: non-error-code results from the `results` parameter
|
| 55 |
+
"""
|
| 56 |
+
# `result` is of the format : (cudaError_t, result...)
|
| 57 |
+
err = result[0]
|
| 58 |
+
if err.value:
|
| 59 |
+
raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))
|
| 60 |
+
|
| 61 |
+
if len(result) == 1:
|
| 62 |
+
return None
|
| 63 |
+
elif len(result) == 2:
|
| 64 |
+
return result[1]
|
| 65 |
+
else:
|
| 66 |
+
return result[1:]
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def device_cc(device: int = -1) -> int:
|
| 70 |
+
"""
|
| 71 |
+
Returns the compute capability of the device with ID `device`.
|
| 72 |
+
|
| 73 |
+
:param device: ID of the device to query
|
| 74 |
+
:type device: int
|
| 75 |
+
|
| 76 |
+
:return: compute capability of the queried device (e.g., 80 for SM80)
|
| 77 |
+
:rtype: int
|
| 78 |
+
"""
|
| 79 |
+
if device == -1:
|
| 80 |
+
device = cutlass_cppgen.device_id()
|
| 81 |
+
|
| 82 |
+
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
|
| 83 |
+
major = str(deviceProp.major)
|
| 84 |
+
minor = str(deviceProp.minor)
|
| 85 |
+
return int(major + minor)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def device_sm_count(device: int = -1):
|
| 89 |
+
if device == -1:
|
| 90 |
+
device = cutlass_cppgen.device_id()
|
| 91 |
+
err, device_sm_count = cuda.cuDeviceGetAttribute(
|
| 92 |
+
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
|
| 93 |
+
)
|
| 94 |
+
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 95 |
+
raise Exception(
|
| 96 |
+
"Failed to retireve SM count. "
|
| 97 |
+
f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
return device_sm_count
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def to_device_ptr(tensor) -> cuda.CUdeviceptr:
|
| 104 |
+
"""
|
| 105 |
+
Converts a tensor to a CUdeviceptr
|
| 106 |
+
|
| 107 |
+
:param tensor: tensor to convert
|
| 108 |
+
:type tensor: np.ndarray | torch.Tensor | cp.ndarray | int
|
| 109 |
+
|
| 110 |
+
:return: device pointer
|
| 111 |
+
:rtype: cuda.CUdeviceptr
|
| 112 |
+
"""
|
| 113 |
+
if is_numpy_tensor(tensor):
|
| 114 |
+
ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
|
| 115 |
+
elif is_torch_tensor(tensor):
|
| 116 |
+
ptr = cuda.CUdeviceptr(tensor.data_ptr())
|
| 117 |
+
elif is_cupy_tensor(tensor):
|
| 118 |
+
ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
|
| 119 |
+
elif isinstance(tensor, cuda.CUdeviceptr):
|
| 120 |
+
ptr = tensor
|
| 121 |
+
elif isinstance(tensor, int):
|
| 122 |
+
ptr = cuda.CUdeviceptr(tensor)
|
| 123 |
+
else:
|
| 124 |
+
raise NotImplementedError(tensor)
|
| 125 |
+
|
| 126 |
+
return ptr
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.emit.pytorch import pytorch
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Common utilities for emitting CUTLASS kernels
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import cutlass_cppgen
|
| 38 |
+
|
| 39 |
+
# Strings used for printing information about the generation of emitted scripts
|
| 40 |
+
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
_CUTLASS_KERNEL_ARGS_2x = """
|
| 51 |
+
typename DeviceKernel::Arguments arguments {
|
| 52 |
+
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 53 |
+
{M, N, K}, // problem size
|
| 54 |
+
1,
|
| 55 |
+
{alpha, beta},
|
| 56 |
+
A, B, C, D,
|
| 57 |
+
0, 0, 0, 0, // batch strides
|
| 58 |
+
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
| 59 |
+
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
| 60 |
+
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
| 61 |
+
DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
|
| 62 |
+
};
|
| 63 |
+
"""
|
| 64 |
+
|
| 65 |
+
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
|
| 66 |
+
typename DeviceKernel::Arguments arguments {
|
| 67 |
+
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 68 |
+
{M, N, K}, // problem size
|
| 69 |
+
1,
|
| 70 |
+
{alpha, beta},
|
| 71 |
+
A, B, C, D,
|
| 72 |
+
0, 0, 0, 0, // batch strides
|
| 73 |
+
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
| 74 |
+
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
| 75 |
+
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
| 76 |
+
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
|
| 77 |
+
-1 // avail_sms
|
| 78 |
+
};
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
_CUTLASS_KERNEL_RUN_GEMM_2x = """
|
| 82 |
+
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
| 83 |
+
|
| 84 |
+
cutlass::Status ${name}_kernel_run(int M, int N, int K,
|
| 85 |
+
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
| 86 |
+
ElementCompute alpha, ElementCompute beta) {
|
| 87 |
+
${args}
|
| 88 |
+
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
| 89 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 90 |
+
|
| 91 |
+
DeviceKernel gemm_op;
|
| 92 |
+
cutlass::Status status = gemm_op.initialize(arguments,
|
| 93 |
+
workspace.get(),
|
| 94 |
+
nullptr); // CUDA stream
|
| 95 |
+
|
| 96 |
+
if (status != cutlass::Status::kSuccess) {
|
| 97 |
+
return status;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
status = gemm_op();
|
| 101 |
+
return status;
|
| 102 |
+
}
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
_CUTLASS_KERNEL_RUN_GEMM_3x = """
|
| 106 |
+
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
|
| 107 |
+
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
|
| 108 |
+
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
|
| 109 |
+
using StrideD = typename DeviceKernel::GemmKernel::StrideD;
|
| 110 |
+
|
| 111 |
+
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
| 112 |
+
|
| 113 |
+
cutlass::Status ${name}_kernel_run(
|
| 114 |
+
int M, int N, int K, int L,
|
| 115 |
+
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
| 116 |
+
ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
|
| 117 |
+
|
| 118 |
+
typename DeviceKernel::Arguments arguments{
|
| 119 |
+
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 120 |
+
{M, N, K, L}, // problem size
|
| 121 |
+
{
|
| 122 |
+
A, // ptrA
|
| 123 |
+
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
| 124 |
+
B, // ptrB
|
| 125 |
+
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
{alpha, beta},
|
| 129 |
+
C, // ptrC
|
| 130 |
+
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
| 131 |
+
D, // ptrD
|
| 132 |
+
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
| 133 |
+
},
|
| 134 |
+
hw_info
|
| 135 |
+
};
|
| 136 |
+
|
| 137 |
+
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
| 138 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 139 |
+
|
| 140 |
+
DeviceKernel gemm_op;
|
| 141 |
+
cutlass::Status status = gemm_op.run(arguments,
|
| 142 |
+
workspace.get(),
|
| 143 |
+
nullptr); // CUDA stream
|
| 144 |
+
|
| 145 |
+
return status;
|
| 146 |
+
}
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
|
| 151 |
+
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
| 152 |
+
|
| 153 |
+
int threadblock_count = DeviceKernel::sufficient();
|
| 154 |
+
|
| 155 |
+
cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
|
| 156 |
+
DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
|
| 157 |
+
int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
|
| 158 |
+
ElementCompute alpha, ElementCompute beta) {
|
| 159 |
+
|
| 160 |
+
typename DeviceKernel::Arguments arguments {
|
| 161 |
+
problem_sizes,
|
| 162 |
+
problem_count,
|
| 163 |
+
threadblock_count,
|
| 164 |
+
{alpha, beta},
|
| 165 |
+
A, B, C, D,
|
| 166 |
+
lda, ldb, ldc, ldd
|
| 167 |
+
};
|
| 168 |
+
|
| 169 |
+
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
| 170 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 171 |
+
|
| 172 |
+
DeviceKernel gemm_op;
|
| 173 |
+
cutlass::Status status = gemm_op.initialize(arguments,
|
| 174 |
+
workspace.get(),
|
| 175 |
+
nullptr); // CUDA stream
|
| 176 |
+
|
| 177 |
+
if (status != cutlass::Status::kSuccess) {
|
| 178 |
+
return status;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
status = gemm_op();
|
| 182 |
+
return status;
|
| 183 |
+
}
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
_CUTLASS_KERNEL_RUN_CONV2D_2x = """
|
| 188 |
+
|
| 189 |
+
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
|
| 190 |
+
namespace {
|
| 191 |
+
using TensorRefA = typename UnderlyingKernel::TensorRefA;
|
| 192 |
+
using TensorRefB = typename UnderlyingKernel::TensorRefB;
|
| 193 |
+
using TensorRefC = typename UnderlyingKernel::TensorRefC;
|
| 194 |
+
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
template<typename TensorRef, typename Element>
|
| 198 |
+
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
|
| 199 |
+
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
|
| 200 |
+
TensorRef tensor_ref(ptr, layout);
|
| 201 |
+
return tensor_ref;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
|
| 205 |
+
UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
|
| 206 |
+
UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
|
| 207 |
+
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
|
| 208 |
+
cudaStream_t stream, int device_id=0) {
|
| 209 |
+
// create the tensor references
|
| 210 |
+
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
|
| 211 |
+
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
| 212 |
+
);
|
| 213 |
+
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
|
| 214 |
+
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
| 215 |
+
);
|
| 216 |
+
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
|
| 217 |
+
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
| 218 |
+
);
|
| 219 |
+
|
| 220 |
+
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
|
| 221 |
+
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
|
| 222 |
+
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
|
| 223 |
+
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
|
| 224 |
+
|
| 225 |
+
cutlass::conv::SplitKMode mode;
|
| 226 |
+
if (split_k_mode == "serial") {
|
| 227 |
+
mode = cutlass::conv::SplitKMode::kSerial;
|
| 228 |
+
} else if (split_k_mode == "parallel") {
|
| 229 |
+
mode = cutlass::conv::SplitKMode::kParallel;
|
| 230 |
+
} else {
|
| 231 |
+
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
typename DeviceKernel::Arguments arguments{
|
| 235 |
+
*problem_size,
|
| 236 |
+
tensor_ref_A,
|
| 237 |
+
tensor_ref_B,
|
| 238 |
+
tensor_ref_C,
|
| 239 |
+
tensor_ref_D,
|
| 240 |
+
{alpha, beta},
|
| 241 |
+
mode
|
| 242 |
+
};
|
| 243 |
+
|
| 244 |
+
DeviceKernel implicit_gemm_op;
|
| 245 |
+
|
| 246 |
+
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
| 247 |
+
|
| 248 |
+
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
|
| 249 |
+
|
| 250 |
+
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
|
| 251 |
+
if (status != cutlass::Status::kSuccess) {
|
| 252 |
+
return status;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
|
| 256 |
+
if (status != cutlass::Status::kSuccess) {
|
| 257 |
+
return status;
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
//
|
| 261 |
+
// Launch initialized CUTLASS kernel
|
| 262 |
+
//
|
| 263 |
+
status = implicit_gemm_op(stream);
|
| 264 |
+
|
| 265 |
+
return status;
|
| 266 |
+
}
|
| 267 |
+
"""
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py
ADDED
|
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel.
|
| 35 |
+
If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.
|
| 36 |
+
|
| 37 |
+
Example usage with JIT compilation:
|
| 38 |
+
|
| 39 |
+
.. highlight:: python
|
| 40 |
+
.. code-block:: python
|
| 41 |
+
|
| 42 |
+
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
|
| 43 |
+
op = plan.construct()
|
| 44 |
+
mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
|
| 45 |
+
|
| 46 |
+
# Generate inputs for the GEMM
|
| 47 |
+
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
| 48 |
+
|
| 49 |
+
# Run the module
|
| 50 |
+
D = mod.run(A, B, C)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
Example usage without JIT compilation:
|
| 54 |
+
|
| 55 |
+
.. highlight:: python
|
| 56 |
+
.. code-block:: python
|
| 57 |
+
|
| 58 |
+
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 59 |
+
op = plan.construct()
|
| 60 |
+
cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
|
| 61 |
+
|
| 62 |
+
After this call, the directory ``output`` contains ``setup.py``,
|
| 63 |
+
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
|
| 64 |
+
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.
|
| 65 |
+
|
| 66 |
+
The module can later be used in Python via:
|
| 67 |
+
|
| 68 |
+
.. highlight:: python
|
| 69 |
+
.. code-block:: python
|
| 70 |
+
|
| 71 |
+
import torch
|
| 72 |
+
import cutlass_gemm
|
| 73 |
+
|
| 74 |
+
# Generate inputs for the GEMM
|
| 75 |
+
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
| 76 |
+
|
| 77 |
+
# Run the module
|
| 78 |
+
D = cutlass_gemm.run(A, B, C)
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
import logging
|
| 82 |
+
import os
|
| 83 |
+
|
| 84 |
+
from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
|
| 85 |
+
|
| 86 |
+
from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
|
| 87 |
+
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
|
| 88 |
+
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
|
| 89 |
+
from cutlass_cppgen.backend.library import ApiVersion
|
| 90 |
+
from cutlass_cppgen.emit import common
|
| 91 |
+
from cutlass_cppgen.utils.datatypes import is_torch_available
|
| 92 |
+
|
| 93 |
+
if is_torch_available():
|
| 94 |
+
import torch
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 98 |
+
#include <cuda_runtime.h>
|
| 99 |
+
#include <torch/extension.h>
|
| 100 |
+
#include <ATen/ATen.h>
|
| 101 |
+
#include <ATen/cuda/CUDAContext.h>
|
| 102 |
+
#include "cutlass/cutlass.h"
|
| 103 |
+
#include "cutlass/util/device_memory.h"
|
| 104 |
+
|
| 105 |
+
// helper function allocating the memory
|
| 106 |
+
void* device_memory_allocation(size_t size, int device_id=0) {
|
| 107 |
+
if (size > 0) {
|
| 108 |
+
torch::Device device(torch::kCUDA, device_id);
|
| 109 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 110 |
+
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
|
| 111 |
+
at::Tensor device_tensor = torch::empty({(long)size,}, options);
|
| 112 |
+
return reinterpret_cast<void*>(device_tensor.data_ptr());
|
| 113 |
+
} else {
|
| 114 |
+
return nullptr;
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
${includes}
|
| 119 |
+
${declaration}
|
| 120 |
+
${impl}
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 124 |
+
#include <torch/extension.h>
|
| 125 |
+
#include <ATen/ATen.h>
|
| 126 |
+
#include <pybind11/stl.h>
|
| 127 |
+
|
| 128 |
+
// CUDA forward declarations
|
| 129 |
+
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);
|
| 130 |
+
|
| 131 |
+
// C++ interface
|
| 132 |
+
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
|
| 133 |
+
return ${name}_kernel(A, B, C, alpha, beta);
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 137 |
+
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
|
| 138 |
+
}
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 142 |
+
#include <torch/extension.h>
|
| 143 |
+
#include <ATen/ATen.h>
|
| 144 |
+
#include <pybind11/stl.h>
|
| 145 |
+
|
| 146 |
+
// CUDA forward declarations
|
| 147 |
+
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);
|
| 148 |
+
|
| 149 |
+
// C++ interface
|
| 150 |
+
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
|
| 151 |
+
return ${name}_kernel(A, B, C, alpha, beta);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 155 |
+
m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
|
| 156 |
+
py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
|
| 157 |
+
}
|
| 158 |
+
"""
|
| 159 |
+
|
| 160 |
+
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 161 |
+
#include <torch/extension.h>
|
| 162 |
+
#include <ATen/ATen.h>
|
| 163 |
+
#include <pybind11/stl.h>
|
| 164 |
+
|
| 165 |
+
// CUDA forward declarations
|
| 166 |
+
at::Tensor ${name}_kernel(
|
| 167 |
+
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 168 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 169 |
+
float alpha=1.f, float beta=0.f,
|
| 170 |
+
std::string split_k_mode="serial", int split_k_slices=1);
|
| 171 |
+
|
| 172 |
+
// C++ interface
|
| 173 |
+
at::Tensor ${name}(
|
| 174 |
+
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 175 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 176 |
+
float alpha=1.f, float beta=0.f,
|
| 177 |
+
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 178 |
+
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 182 |
+
m.def("run",
|
| 183 |
+
py::overload_cast<
|
| 184 |
+
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
| 185 |
+
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
| 186 |
+
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
| 187 |
+
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
| 188 |
+
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
| 189 |
+
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
| 190 |
+
}
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 194 |
+
#include <torch/extension.h>
|
| 195 |
+
#include <ATen/ATen.h>
|
| 196 |
+
#include <pybind11/stl.h>
|
| 197 |
+
|
| 198 |
+
// CUDA forward declarations
|
| 199 |
+
at::Tensor ${name}_kernel(
|
| 200 |
+
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 201 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 202 |
+
float alpha=1.f, float beta=0.f,
|
| 203 |
+
std::string split_k_mode="serial", int split_k_slices=1);
|
| 204 |
+
|
| 205 |
+
// C++ interface
|
| 206 |
+
at::Tensor ${name}(
|
| 207 |
+
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 208 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 209 |
+
float alpha=1.f, float beta=0.f,
|
| 210 |
+
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 211 |
+
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 215 |
+
m.def("run",
|
| 216 |
+
py::overload_cast<
|
| 217 |
+
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
| 218 |
+
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
| 219 |
+
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
| 220 |
+
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
| 221 |
+
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
| 222 |
+
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
| 223 |
+
}
|
| 224 |
+
"""
|
| 225 |
+
|
| 226 |
+
_PYTORCH_GEMM_INCLUDES = {
|
| 227 |
+
ApiVersion.v2x: """
|
| 228 |
+
#include "cutlass/gemm/device/gemm_universal.h"
|
| 229 |
+
""",
|
| 230 |
+
ApiVersion.v3x: """
|
| 231 |
+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 232 |
+
#include "cutlass/gemm/collective/collective_builder.hpp"
|
| 233 |
+
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 234 |
+
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
| 235 |
+
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
| 236 |
+
#include "cutlass/util/packed_stride.hpp"
|
| 237 |
+
""",
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
_PYTORCH_GROUPED_GEMM_INCLUDES = """
|
| 241 |
+
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
| 242 |
+
#include "cutlass/gemm/device/gemm_grouped.h"
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
_PYTORCH_CONV2D_INCLUDES = """
|
| 246 |
+
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
| 247 |
+
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
|
| 248 |
+
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
|
| 249 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 250 |
+
"""
|
| 251 |
+
|
| 252 |
+
_CUTLASS_TYPE_TO_TORCH_TYPE = {
|
| 253 |
+
DataType.f16: "torch::kF16",
|
| 254 |
+
DataType.f32: "torch::kF32",
|
| 255 |
+
DataType.f64: "torch::kF64",
|
| 256 |
+
DataType.s8: "torch::kI8",
|
| 257 |
+
DataType.s32: "torch::kI32",
|
| 258 |
+
DataType.bf16: "torch::kBFloat16",
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
|
| 262 |
+
common._CUTLASS_KERNEL_RUN_GEMM_2x
|
| 263 |
+
+ """
|
| 264 |
+
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
|
| 265 |
+
int M = A.size(0);
|
| 266 |
+
int N = B.size(1);
|
| 267 |
+
int K = A.size(1);
|
| 268 |
+
|
| 269 |
+
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 270 |
+
nullptr :
|
| 271 |
+
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
|
| 272 |
+
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
|
| 273 |
+
|
| 274 |
+
cutlass::Status status = ${name}_kernel_run(M, N, K,
|
| 275 |
+
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
|
| 276 |
+
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
|
| 277 |
+
ptrC,
|
| 278 |
+
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
|
| 279 |
+
ElementCompute(alpha), ElementCompute(beta));
|
| 280 |
+
|
| 281 |
+
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 282 |
+
return D;
|
| 283 |
+
}
|
| 284 |
+
"""
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
|
| 288 |
+
common._CUTLASS_KERNEL_RUN_GEMM_3x
|
| 289 |
+
+ """
|
| 290 |
+
bool hw_info_queried = false;
|
| 291 |
+
cutlass::KernelHardwareInfo hw_info;
|
| 292 |
+
|
| 293 |
+
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
|
| 294 |
+
int M = A.size(0);
|
| 295 |
+
int N = B.size(1);
|
| 296 |
+
int K = A.size(1);
|
| 297 |
+
int L = 1;
|
| 298 |
+
|
| 299 |
+
// Query hardware info if we haven't already
|
| 300 |
+
if (!hw_info_queried) {
|
| 301 |
+
hw_info.device_id = 0;
|
| 302 |
+
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 306 |
+
nullptr :
|
| 307 |
+
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
|
| 308 |
+
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
|
| 309 |
+
|
| 310 |
+
cutlass::Status status = ${name}_kernel_run(M, N, K, L,
|
| 311 |
+
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
|
| 312 |
+
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
|
| 313 |
+
ptrC,
|
| 314 |
+
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
|
| 315 |
+
ElementCompute(alpha), ElementCompute(beta),
|
| 316 |
+
hw_info);
|
| 317 |
+
|
| 318 |
+
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 319 |
+
return D;
|
| 320 |
+
}
|
| 321 |
+
"""
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
|
| 326 |
+
common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
|
| 327 |
+
+ """
|
| 328 |
+
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
|
| 329 |
+
size_t num = A.size();
|
| 330 |
+
|
| 331 |
+
// To avoid performing many small cudaMallocs and host-to-device copies,
|
| 332 |
+
// we serialize the grouped GEMM arguments on the host, allocate one
|
| 333 |
+
// large chunk of device memory, and perform a single cudaMemcpy to
|
| 334 |
+
// copy the host data to the device. Allocation overheads could be
|
| 335 |
+
// avoided by using a memory pool.
|
| 336 |
+
|
| 337 |
+
// Calculate the total size of the data to be copied from host to device
|
| 338 |
+
size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
|
| 339 |
+
sizeof(DeviceKernel::ElementA*) +
|
| 340 |
+
sizeof(DeviceKernel::ElementB*) +
|
| 341 |
+
sizeof(DeviceKernel::ElementC*) +
|
| 342 |
+
sizeof(DeviceKernel::ElementC*) +
|
| 343 |
+
sizeof(int64_t) +
|
| 344 |
+
sizeof(int64_t) +
|
| 345 |
+
sizeof(int64_t);
|
| 346 |
+
total_size *= num;
|
| 347 |
+
|
| 348 |
+
// num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
|
| 349 |
+
// of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
|
| 350 |
+
// To ensure that we don't end up having misaligned loads in the kernel,
|
| 351 |
+
// we pad to the nearest multiple of 8.
|
| 352 |
+
//
|
| 353 |
+
// Note that, even on a 32-bit system (for which sizeof(X*) will not equal
|
| 354 |
+
// sizeof(int64_t)), only padding between the list of GemmCoords and the
|
| 355 |
+
// list of ptr_As is sufficient because the set of four equal-length lists of pointers
|
| 356 |
+
// (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
|
| 357 |
+
// start on a multiple of 8.
|
| 358 |
+
int64_t padding = 8 - (total_size % 8);
|
| 359 |
+
total_size += padding;
|
| 360 |
+
|
| 361 |
+
uint8_t* host_data = new uint8_t[total_size];
|
| 362 |
+
cutlass::DeviceAllocation<uint8_t> device_data(total_size);
|
| 363 |
+
|
| 364 |
+
uint8_t* start = host_data;
|
| 365 |
+
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);
|
| 366 |
+
|
| 367 |
+
// Apply the padding after the list of GemmCoords
|
| 368 |
+
start += num * sizeof(cutlass::gemm::GemmCoord) + padding;
|
| 369 |
+
|
| 370 |
+
int64_t ptr_A_offset = start - host_data;
|
| 371 |
+
DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
|
| 372 |
+
start += num * sizeof(DeviceKernel::ElementA*);
|
| 373 |
+
|
| 374 |
+
int64_t ptr_B_offset = start - host_data;
|
| 375 |
+
DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
|
| 376 |
+
start += num * sizeof(DeviceKernel::ElementB*);
|
| 377 |
+
|
| 378 |
+
int64_t ptr_C_offset = start - host_data;
|
| 379 |
+
DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
|
| 380 |
+
start += num * sizeof(DeviceKernel::ElementC*);
|
| 381 |
+
|
| 382 |
+
int64_t ptr_D_offset = start - host_data;
|
| 383 |
+
DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
|
| 384 |
+
start += num * sizeof(DeviceKernel::ElementC*);
|
| 385 |
+
|
| 386 |
+
int64_t lda_offset = start - host_data;
|
| 387 |
+
int64_t* lda_host = reinterpret_cast<int64_t*>(start);
|
| 388 |
+
start += num * sizeof(int64_t);
|
| 389 |
+
|
| 390 |
+
int64_t ldb_offset = start - host_data;
|
| 391 |
+
int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
|
| 392 |
+
start += num * sizeof(int64_t);
|
| 393 |
+
|
| 394 |
+
int64_t ldc_offset = start - host_data;
|
| 395 |
+
int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
|
| 396 |
+
start += num * sizeof(int64_t);
|
| 397 |
+
|
| 398 |
+
std::vector<at::Tensor> D(num);
|
| 399 |
+
|
| 400 |
+
bool need_C = (C != at::nullopt) && (beta != 0.f);
|
| 401 |
+
for (size_t i = 0; i < num; ++i) {
|
| 402 |
+
int M = A[i].size(0);
|
| 403 |
+
int N = B[i].size(1);
|
| 404 |
+
int K = A[i].size(1);
|
| 405 |
+
*(problem_sizes_host + i) = {M, N, K};
|
| 406 |
+
*(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
|
| 407 |
+
*(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());
|
| 408 |
+
|
| 409 |
+
if (need_C) {
|
| 410 |
+
*(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
|
| 411 |
+
}
|
| 412 |
+
else {
|
| 413 |
+
*(ptr_C_host + i) = nullptr;
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
D[i] = B[i].new_empty({M, N}, ${torch_type_C});
|
| 417 |
+
*(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());
|
| 418 |
+
|
| 419 |
+
*(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
|
| 420 |
+
*(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
|
| 421 |
+
*(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
device_data.copy_from_host(host_data);
|
| 425 |
+
|
| 426 |
+
cutlass::Status status = ${name}_kernel_run(
|
| 427 |
+
num,
|
| 428 |
+
reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
|
| 429 |
+
reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
|
| 430 |
+
reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
|
| 431 |
+
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
|
| 432 |
+
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
|
| 433 |
+
reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
|
| 434 |
+
reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
|
| 435 |
+
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
|
| 436 |
+
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
|
| 437 |
+
ElementCompute(alpha), ElementCompute(beta));
|
| 438 |
+
|
| 439 |
+
delete[] host_data;
|
| 440 |
+
|
| 441 |
+
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 442 |
+
return D;
|
| 443 |
+
}
|
| 444 |
+
"""
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
|
| 448 |
+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 449 |
+
|
| 450 |
+
cutlass::Status status = ${name}_kernel_run(
|
| 451 |
+
&problem_size,
|
| 452 |
+
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
|
| 453 |
+
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
|
| 454 |
+
ptrC,
|
| 455 |
+
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
|
| 456 |
+
alpha, beta,
|
| 457 |
+
split_k_mode, stream, B.device().index());
|
| 458 |
+
|
| 459 |
+
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 460 |
+
return D;
|
| 461 |
+
}
|
| 462 |
+
"""
|
| 463 |
+
|
| 464 |
+
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
|
| 465 |
+
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
| 466 |
+
+ """
|
| 467 |
+
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 468 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 469 |
+
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
|
| 470 |
+
int N, H, W, C_, K, R, S, P, Q;
|
| 471 |
+
N = A.size(0);
|
| 472 |
+
C_ = A.size(1);
|
| 473 |
+
H = A.size(2);
|
| 474 |
+
W = A.size(3);
|
| 475 |
+
|
| 476 |
+
K = B.size(0);
|
| 477 |
+
R = B.size(2);
|
| 478 |
+
S = B.size(3);
|
| 479 |
+
|
| 480 |
+
cutlass::conv::Conv2dProblemSize problem_size(
|
| 481 |
+
cutlass::Tensor4DCoord(N, H, W, C_),
|
| 482 |
+
cutlass::Tensor4DCoord(K, R, S, C_),
|
| 483 |
+
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
| 484 |
+
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
| 485 |
+
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
| 486 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 487 |
+
split_k_slices
|
| 488 |
+
);
|
| 489 |
+
|
| 490 |
+
P = problem_size.P;
|
| 491 |
+
Q = problem_size.Q;
|
| 492 |
+
|
| 493 |
+
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 494 |
+
nullptr :
|
| 495 |
+
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
| 496 |
+
|
| 497 |
+
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
| 498 |
+
at::Tensor D = torch::zeros({N, K, P, Q}, options);
|
| 499 |
+
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
|
| 504 |
+
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
| 505 |
+
+ """
|
| 506 |
+
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 507 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
| 508 |
+
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 509 |
+
int N, H, W, C_, K, R, S;
|
| 510 |
+
N = std::get<0>(input_size);
|
| 511 |
+
C_ = std::get<1>(input_size);
|
| 512 |
+
H = std::get<2>(input_size);
|
| 513 |
+
W = std::get<3>(input_size);
|
| 514 |
+
|
| 515 |
+
K = B.size(0);
|
| 516 |
+
R = B.size(2);
|
| 517 |
+
S = B.size(3);
|
| 518 |
+
|
| 519 |
+
cutlass::conv::Conv2dProblemSize problem_size(
|
| 520 |
+
cutlass::Tensor4DCoord(N, H, W, C_),
|
| 521 |
+
cutlass::Tensor4DCoord(K, R, S, C_),
|
| 522 |
+
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
| 523 |
+
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
| 524 |
+
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
| 525 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 526 |
+
split_k_slices
|
| 527 |
+
);
|
| 528 |
+
|
| 529 |
+
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 530 |
+
nullptr :
|
| 531 |
+
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
| 532 |
+
|
| 533 |
+
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
| 534 |
+
at::Tensor D = torch::empty({N, C_, H, W}, options);
|
| 535 |
+
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
|
| 540 |
+
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
| 541 |
+
+ """
|
| 542 |
+
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 543 |
+
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
| 544 |
+
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 545 |
+
int N, H, W, C_, K, R, S;
|
| 546 |
+
K = std::get<0>(weight_size);
|
| 547 |
+
C_ = std::get<1>(weight_size);
|
| 548 |
+
R = std::get<2>(weight_size);
|
| 549 |
+
S = std::get<3>(weight_size);
|
| 550 |
+
|
| 551 |
+
N = B.size(0);
|
| 552 |
+
H = B.size(2);
|
| 553 |
+
W = B.size(3);
|
| 554 |
+
|
| 555 |
+
cutlass::conv::Conv2dProblemSize problem_size(
|
| 556 |
+
cutlass::Tensor4DCoord(N, H, W, C_),
|
| 557 |
+
cutlass::Tensor4DCoord(K, R, S, C_),
|
| 558 |
+
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
| 559 |
+
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
| 560 |
+
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
| 561 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 562 |
+
split_k_slices
|
| 563 |
+
);
|
| 564 |
+
|
| 565 |
+
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 566 |
+
nullptr :
|
| 567 |
+
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
| 568 |
+
|
| 569 |
+
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
| 570 |
+
at::Tensor D = torch::empty({K, C_, R, S}, options);
|
| 571 |
+
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
|
| 576 |
+
from setuptools import setup
|
| 577 |
+
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 578 |
+
|
| 579 |
+
setup(
|
| 580 |
+
name='${name}',
|
| 581 |
+
ext_modules=[
|
| 582 |
+
CUDAExtension('${name}', [
|
| 583 |
+
'${name}.cpp',
|
| 584 |
+
'${name}_kernel.cu',
|
| 585 |
+
],
|
| 586 |
+
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
|
| 587 |
+
extra_compile_args={
|
| 588 |
+
'cxx': ['-std=c++17'],
|
| 589 |
+
'nvcc': ['-std=c++17', ${extra_compile_args}],
|
| 590 |
+
},
|
| 591 |
+
libraries=['cuda']
|
| 592 |
+
),
|
| 593 |
+
],
|
| 594 |
+
cmdclass={
|
| 595 |
+
'build_ext': BuildExtension
|
| 596 |
+
})
|
| 597 |
+
|
| 598 |
+
"""
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
|
| 602 |
+
"""
|
| 603 |
+
Generates a setup.py file for the extension
|
| 604 |
+
|
| 605 |
+
:param name: name of the module to generate
|
| 606 |
+
:type name: str
|
| 607 |
+
:param sourcedir: directory to which generated source files should be written
|
| 608 |
+
:type sourcedir: str
|
| 609 |
+
:param extra_compile_args: additional arguments to pass to setup.py
|
| 610 |
+
:type extra_args: str
|
| 611 |
+
"""
|
| 612 |
+
setup_py_file = os.path.join(sourcedir, "setup.py")
|
| 613 |
+
setup_source = SubstituteTemplate(
|
| 614 |
+
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
|
| 615 |
+
)
|
| 616 |
+
with open(setup_py_file, "w") as outfile:
|
| 617 |
+
outfile.write(setup_source)
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
class _ArchListSetter:
|
| 621 |
+
"""
|
| 622 |
+
Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
|
| 623 |
+
environment variable when building a PyTorch CUDA module.
|
| 624 |
+
|
| 625 |
+
``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch
|
| 626 |
+
CUDA module should be compiled.
|
| 627 |
+
|
| 628 |
+
For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
|
| 629 |
+
``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
|
| 630 |
+
compilation of the module.
|
| 631 |
+
|
| 632 |
+
This utility wraps the building of a PyTorch CUDA module with a setting of this environment
|
| 633 |
+
variable according to the current compute capability being targetted.
|
| 634 |
+
|
| 635 |
+
Example usage:
|
| 636 |
+
|
| 637 |
+
.. highlight:: python
|
| 638 |
+
.. code-block:: python
|
| 639 |
+
|
| 640 |
+
# Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
|
| 641 |
+
with _ArchListSetter(80):
|
| 642 |
+
# Perform JIT compilation and loading of the module
|
| 643 |
+
mod = torch.utils.cpp_extension.load(...)
|
| 644 |
+
|
| 645 |
+
:param cc: compute capability
|
| 646 |
+
:type cc: int
|
| 647 |
+
"""
|
| 648 |
+
|
| 649 |
+
_TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"
|
| 650 |
+
|
| 651 |
+
def __init__(self, cc: int):
|
| 652 |
+
self.cc_str = ".".join(list(str(cc)))
|
| 653 |
+
|
| 654 |
+
def __enter__(self):
|
| 655 |
+
"""
|
| 656 |
+
Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc``
|
| 657 |
+
"""
|
| 658 |
+
self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
|
| 659 |
+
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str
|
| 660 |
+
|
| 661 |
+
return self
|
| 662 |
+
|
| 663 |
+
def __exit__(self, exc_type, exc_val, traceback):
|
| 664 |
+
"""
|
| 665 |
+
Restores the old value of TORCH_CUDA_ARCH_LIST
|
| 666 |
+
"""
|
| 667 |
+
if self.old_arch_list is None:
|
| 668 |
+
del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
|
| 669 |
+
else:
|
| 670 |
+
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
|
| 674 |
+
"""
|
| 675 |
+
JIT compiles and loads a PyTorch CUDA extension.
|
| 676 |
+
|
| 677 |
+
:param name: name of the module to generate
|
| 678 |
+
:type name: str
|
| 679 |
+
:param cc: compute capability of the device the module should target
|
| 680 |
+
:type cc: int
|
| 681 |
+
:param cpp_file: path to file containing extension's C++ interface
|
| 682 |
+
:type cpp_file: str
|
| 683 |
+
:param cuda_file: path to file containing extension's CUDA interface
|
| 684 |
+
:type cuda_file: str
|
| 685 |
+
|
| 686 |
+
:return: loaded PyTorch module
|
| 687 |
+
"""
|
| 688 |
+
|
| 689 |
+
from torch.utils.cpp_extension import load
|
| 690 |
+
|
| 691 |
+
extra_cuda_cflags = ["-std=c++17"]
|
| 692 |
+
if cc in [90, 100, 101, 103]:
|
| 693 |
+
# PyTorch does not currently add the sm_90a target when compute capability
|
| 694 |
+
# 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.
|
| 695 |
+
extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")
|
| 696 |
+
|
| 697 |
+
with _ArchListSetter(cc):
|
| 698 |
+
jitmodule = load(
|
| 699 |
+
name,
|
| 700 |
+
[cpp_file, cuda_file],
|
| 701 |
+
extra_cuda_cflags=extra_cuda_cflags,
|
| 702 |
+
extra_include_paths=[
|
| 703 |
+
os.path.join(CUTLASS_PATH, "include"),
|
| 704 |
+
os.path.join(CUTLASS_PATH, "tools/util/include"),
|
| 705 |
+
],
|
| 706 |
+
extra_ldflags=["-lcuda"],
|
| 707 |
+
verbose=(logger.level == logging.DEBUG)
|
| 708 |
+
)
|
| 709 |
+
return jitmodule
|
| 710 |
+
|
| 711 |
+
|
| 712 |
+
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
| 713 |
+
"""
|
| 714 |
+
Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
|
| 715 |
+
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 716 |
+
compiled, loaded, and returned.
|
| 717 |
+
|
| 718 |
+
:param op: operation to emit in the module
|
| 719 |
+
:param name: name of the module to generate
|
| 720 |
+
:type name: str
|
| 721 |
+
:param cc: compute capability of the device the module should target
|
| 722 |
+
:type cc: int
|
| 723 |
+
:param jit: whether the module should be just-in-time compiled
|
| 724 |
+
:type jit: bool
|
| 725 |
+
:param sourcedir: directory to which generated source files should be written
|
| 726 |
+
:type sourcedir: str
|
| 727 |
+
|
| 728 |
+
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
| 729 |
+
"""
|
| 730 |
+
if sourcedir != "" and not os.path.isdir(sourcedir):
|
| 731 |
+
os.makedirs(sourcedir)
|
| 732 |
+
|
| 733 |
+
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
| 734 |
+
extra_kw = {}
|
| 735 |
+
if op.api == ApiVersion.v3x:
|
| 736 |
+
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
|
| 737 |
+
else:
|
| 738 |
+
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
|
| 739 |
+
if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
|
| 740 |
+
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
|
| 741 |
+
else:
|
| 742 |
+
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
|
| 743 |
+
impl_template = (
|
| 744 |
+
_PYTORCH_GEMM_IMPL_TEMPLATE_3x
|
| 745 |
+
if op.api == ApiVersion.v3x
|
| 746 |
+
else _PYTORCH_GEMM_IMPL_TEMPLATE_2x
|
| 747 |
+
)
|
| 748 |
+
cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
|
| 749 |
+
cuda_source = SubstituteTemplate(
|
| 750 |
+
_PYTORCH_CUDA_TEMPLATE,
|
| 751 |
+
{
|
| 752 |
+
"includes": _PYTORCH_GEMM_INCLUDES[op.api],
|
| 753 |
+
"declaration": op.rt_module.emit(),
|
| 754 |
+
"procedural_name": op.procedural_name(),
|
| 755 |
+
"impl": cuda_impl,
|
| 756 |
+
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
| 757 |
+
},
|
| 758 |
+
)
|
| 759 |
+
with open(cuda_file, "w") as outfile:
|
| 760 |
+
outfile.write(cuda_source)
|
| 761 |
+
|
| 762 |
+
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
| 763 |
+
cpp_source = SubstituteTemplate(
|
| 764 |
+
_PYTORCH_GEMM_CPP_TEMPLATE,
|
| 765 |
+
{"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
|
| 766 |
+
)
|
| 767 |
+
with open(cpp_file, "w") as outfile:
|
| 768 |
+
outfile.write(cpp_source)
|
| 769 |
+
|
| 770 |
+
extra_compile_args = ""
|
| 771 |
+
if cc in [90, 100, 101, 103]:
|
| 772 |
+
extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
|
| 773 |
+
_generate_setup(name, sourcedir, extra_compile_args)
|
| 774 |
+
|
| 775 |
+
if jit:
|
| 776 |
+
return _jit(name, cc, cpp_file, cuda_file)
|
| 777 |
+
|
| 778 |
+
return None
|
| 779 |
+
|
| 780 |
+
|
| 781 |
+
def _pytorch_grouped_gemm(
|
| 782 |
+
op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
|
| 783 |
+
):
|
| 784 |
+
"""
|
| 785 |
+
Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM
|
| 786 |
+
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 787 |
+
compiled, loaded, and returned.
|
| 788 |
+
|
| 789 |
+
:param op: operation to emit in the module
|
| 790 |
+
:param name: name of the module to generate
|
| 791 |
+
:type name: str
|
| 792 |
+
:param cc: compute capability of the device the module should target
|
| 793 |
+
:type cc: int
|
| 794 |
+
:param jit: whether the module should be just-in-time compiled
|
| 795 |
+
:type jit: bool
|
| 796 |
+
:param sourcedir: directory to which generated source files should be written
|
| 797 |
+
:type sourcedir: str
|
| 798 |
+
|
| 799 |
+
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
| 800 |
+
"""
|
| 801 |
+
if op.api != ApiVersion.v2x:
|
| 802 |
+
raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")
|
| 803 |
+
|
| 804 |
+
if sourcedir != "" and not os.path.isdir(sourcedir):
|
| 805 |
+
os.makedirs(sourcedir)
|
| 806 |
+
|
| 807 |
+
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
| 808 |
+
cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
|
| 809 |
+
cuda_source = SubstituteTemplate(
|
| 810 |
+
_PYTORCH_CUDA_TEMPLATE,
|
| 811 |
+
{
|
| 812 |
+
"includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
|
| 813 |
+
"declaration": op.rt_module.emit(),
|
| 814 |
+
"procedural_name": op.procedural_name(),
|
| 815 |
+
"impl": cuda_impl,
|
| 816 |
+
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
| 817 |
+
},
|
| 818 |
+
)
|
| 819 |
+
with open(cuda_file, "w") as outfile:
|
| 820 |
+
outfile.write(cuda_source)
|
| 821 |
+
|
| 822 |
+
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
| 823 |
+
cpp_source = SubstituteTemplate(
|
| 824 |
+
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE,
|
| 825 |
+
{"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"},
|
| 826 |
+
)
|
| 827 |
+
with open(cpp_file, "w") as outfile:
|
| 828 |
+
outfile.write(cpp_source)
|
| 829 |
+
|
| 830 |
+
_generate_setup(name, sourcedir)
|
| 831 |
+
|
| 832 |
+
if jit:
|
| 833 |
+
return _jit(name, cc, cpp_file, cuda_file)
|
| 834 |
+
|
| 835 |
+
return None
|
| 836 |
+
|
| 837 |
+
|
| 838 |
+
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
| 839 |
+
"""
|
| 840 |
+
Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
|
| 841 |
+
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 842 |
+
compiled, loaded, and returned.
|
| 843 |
+
|
| 844 |
+
:param op: operation to emit in the module
|
| 845 |
+
:param name: name of the module to generate
|
| 846 |
+
:type name: str
|
| 847 |
+
:param cc: compute capability of the device the module should target
|
| 848 |
+
:type cc: int
|
| 849 |
+
:param jit: whether the module should be just-in-time compiled
|
| 850 |
+
:type jit: bool
|
| 851 |
+
:param sourcedir: directory to which generated source files should be written
|
| 852 |
+
:type sourcedir: str
|
| 853 |
+
|
| 854 |
+
Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
|
| 855 |
+
weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
|
| 856 |
+
for H/W/R/S given the same P/Q.
|
| 857 |
+
|
| 858 |
+
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
| 859 |
+
"""
|
| 860 |
+
if sourcedir != "" and not os.path.isdir(sourcedir):
|
| 861 |
+
os.makedirs(sourcedir)
|
| 862 |
+
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
| 863 |
+
extra_kw = {}
|
| 864 |
+
if op.conv_kind == ConvKind.Fprop:
|
| 865 |
+
impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
|
| 866 |
+
cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
|
| 867 |
+
elif op.conv_kind == ConvKind.Dgrad:
|
| 868 |
+
impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
|
| 869 |
+
cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
|
| 870 |
+
elif op.conv_kind == ConvKind.Wgrad:
|
| 871 |
+
impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
|
| 872 |
+
cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
|
| 873 |
+
extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
|
| 874 |
+
extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
|
| 875 |
+
cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
|
| 876 |
+
cuda_source = SubstituteTemplate(
|
| 877 |
+
_PYTORCH_CUDA_TEMPLATE,
|
| 878 |
+
{
|
| 879 |
+
"includes": _PYTORCH_CONV2D_INCLUDES,
|
| 880 |
+
"declaration": op.rt_module.emit(),
|
| 881 |
+
"procedural_name": op.procedural_name(),
|
| 882 |
+
"impl": cuda_impl,
|
| 883 |
+
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
| 884 |
+
},
|
| 885 |
+
)
|
| 886 |
+
with open(cuda_file, "w") as outfile:
|
| 887 |
+
outfile.write(cuda_source)
|
| 888 |
+
|
| 889 |
+
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
| 890 |
+
cpp_source = SubstituteTemplate(
|
| 891 |
+
cpp_template,
|
| 892 |
+
{"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
|
| 893 |
+
)
|
| 894 |
+
with open(cpp_file, "w") as outfile:
|
| 895 |
+
outfile.write(cpp_source)
|
| 896 |
+
|
| 897 |
+
_generate_setup(name, sourcedir)
|
| 898 |
+
|
| 899 |
+
if jit:
|
| 900 |
+
return _jit(name, cc, cpp_file, cuda_file)
|
| 901 |
+
|
| 902 |
+
return None
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
| 906 |
+
"""
|
| 907 |
+
Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
|
| 908 |
+
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 909 |
+
compiled, loaded, and returned.
|
| 910 |
+
|
| 911 |
+
The result of this method is files within ``sourcedir`` that can be used for building
|
| 912 |
+
a PyTorch module.
|
| 913 |
+
|
| 914 |
+
:param op: operation to emit in the module
|
| 915 |
+
:param name: name of the module to generate
|
| 916 |
+
:type name: str
|
| 917 |
+
:param cc: compute capability of the device the module should target
|
| 918 |
+
:type cc: int
|
| 919 |
+
:param jit: whether the module should be just-in-time compiled
|
| 920 |
+
:type jit: bool
|
| 921 |
+
:param sourcedir: directory to which generated source files should be written
|
| 922 |
+
:type sourcedir: str
|
| 923 |
+
|
| 924 |
+
:return: loaded PyTorch module (if ``jit=True``) or None
|
| 925 |
+
"""
|
| 926 |
+
device_op = op.device_op()
|
| 927 |
+
if isinstance(op, GemmOperationUniversal):
|
| 928 |
+
return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
|
| 929 |
+
elif isinstance(op, GemmOperationGrouped):
|
| 930 |
+
return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
|
| 931 |
+
elif isinstance(op, Conv2dOperation):
|
| 932 |
+
return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
|
| 933 |
+
else:
|
| 934 |
+
raise Exception(
|
| 935 |
+
f"Operation type {type(op)} is not currently supported for PyTorch emission."
|
| 936 |
+
)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.epilogue.epilogue import (
|
| 34 |
+
get_activations,
|
| 35 |
+
get_activation_epilogue,
|
| 36 |
+
gelu,
|
| 37 |
+
hardswish,
|
| 38 |
+
identity,
|
| 39 |
+
leaky_relu,
|
| 40 |
+
relu,
|
| 41 |
+
sigmoid,
|
| 42 |
+
silu,
|
| 43 |
+
tanh,
|
| 44 |
+
trace
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
from cutlass_cppgen.epilogue.evt_ops import (
|
| 48 |
+
max,
|
| 49 |
+
multiply_add,
|
| 50 |
+
sum,
|
| 51 |
+
permute,
|
| 52 |
+
reshape,
|
| 53 |
+
maximum,
|
| 54 |
+
minimum,
|
| 55 |
+
exp
|
| 56 |
+
)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Registry of elementwise epilogues
|
| 35 |
+
|
| 36 |
+
Elementwise epilogues can be added to many CUTLASS kernels in the CUTLAS Python interface via
|
| 37 |
+
code like the following for GEMM:
|
| 38 |
+
|
| 39 |
+
.. highlight:: python
|
| 40 |
+
.. code-block:: python
|
| 41 |
+
|
| 42 |
+
plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 43 |
+
plan.activation = cutlass_cppgen.epilogue.relu
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
from cutlass_cppgen.backend import epilogue, device_cc
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
gelu = epilogue.gelu
|
| 50 |
+
hardswish = epilogue.hardswish
|
| 51 |
+
identity = epilogue.identity
|
| 52 |
+
leaky_relu = epilogue.leaky_relu
|
| 53 |
+
relu = epilogue.relu
|
| 54 |
+
sigmoid = epilogue.sigmoid
|
| 55 |
+
silu = epilogue.silu
|
| 56 |
+
tanh = epilogue.tanh
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_activations() -> list:
|
| 63 |
+
"""
|
| 64 |
+
Returns a list of available activation functions
|
| 65 |
+
|
| 66 |
+
:return: list of available activation functions
|
| 67 |
+
:rtype: list
|
| 68 |
+
"""
|
| 69 |
+
return _activations
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_activation_epilogue(
|
| 73 |
+
activation,
|
| 74 |
+
element_output,
|
| 75 |
+
elements_per_access,
|
| 76 |
+
element_accumulator,
|
| 77 |
+
element_compute,
|
| 78 |
+
):
|
| 79 |
+
"""
|
| 80 |
+
Return an epilogue corresponding to the activation function, data types, and alignment
|
| 81 |
+
used in the kernel
|
| 82 |
+
|
| 83 |
+
:param activation: elementwise activation function to use
|
| 84 |
+
:param element_output: data type of the output
|
| 85 |
+
:param elements_per_access: alignment of operand C of the kernel
|
| 86 |
+
:type elements_per_access: int
|
| 87 |
+
:param element_accumulator: data type of the accumulated output C
|
| 88 |
+
:param element_compute: data type in which compute operations should be performed
|
| 89 |
+
|
| 90 |
+
:return: epilogue functor
|
| 91 |
+
"""
|
| 92 |
+
if activation not in _activations:
|
| 93 |
+
raise Exception(
|
| 94 |
+
f"Unsupported activation type {activation}. Available activations are: {_activations}"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
if activation == identity:
|
| 98 |
+
return epilogue.LinearCombination(
|
| 99 |
+
element_output, elements_per_access, element_accumulator, element_compute
|
| 100 |
+
)
|
| 101 |
+
else:
|
| 102 |
+
return epilogue.LinearCombinationGeneric(
|
| 103 |
+
activation,
|
| 104 |
+
element_output,
|
| 105 |
+
elements_per_access,
|
| 106 |
+
element_accumulator,
|
| 107 |
+
element_compute,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
"""
|
| 112 |
+
Frontend for EVT that generates epilogue functor through tracing the input function
|
| 113 |
+
"""
|
| 114 |
+
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def trace(fn, example_tensors, **kwargs):
|
| 118 |
+
"""
|
| 119 |
+
Trace `fn(**example_tensors)` and generates epilogue visitor
|
| 120 |
+
|
| 121 |
+
:param fn or str: Python callable or string of the epilogue function
|
| 122 |
+
:param example_tensors: example inputs for fn
|
| 123 |
+
:type example_tensors: dict
|
| 124 |
+
|
| 125 |
+
.. hightlight:: python
|
| 126 |
+
.. code-block:: python
|
| 127 |
+
import cutlass_cppgen.backend.evt
|
| 128 |
+
|
| 129 |
+
# Define epilogue function as Python callable
|
| 130 |
+
def example_fn(accum, C, alpha, beta, gamma):
|
| 131 |
+
D = ((accum + C) * alpha - gamma) / beta
|
| 132 |
+
return D
|
| 133 |
+
|
| 134 |
+
# Define the example tensors
|
| 135 |
+
example_inputs = {
|
| 136 |
+
"accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
|
| 137 |
+
"C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
|
| 138 |
+
"alpha": 1.5,
|
| 139 |
+
"beta": 0.5,
|
| 140 |
+
"gamma": 2.5,
|
| 141 |
+
"D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
# Generate the epilogue functor
|
| 145 |
+
epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
|
| 146 |
+
"""
|
| 147 |
+
if callable(fn):
|
| 148 |
+
class EpilogueFunctor(PythonASTFrontend):
|
| 149 |
+
def __init__(self, cc=None, **kwargs):
|
| 150 |
+
if not cc:
|
| 151 |
+
cc = device_cc()
|
| 152 |
+
super().__init__(cc, **kwargs)
|
| 153 |
+
pass
|
| 154 |
+
setattr(EpilogueFunctor, "__call__", staticmethod(fn))
|
| 155 |
+
|
| 156 |
+
epilogue_functor = EpilogueFunctor(**kwargs)
|
| 157 |
+
epilogue_functor.trace(example_tensors)
|
| 158 |
+
return epilogue_functor
|
| 159 |
+
elif isinstance(fn, str):
|
| 160 |
+
class EpilogueFunctor(PythonASTFrontend):
|
| 161 |
+
def __init__(self, cc=None, **kwargs):
|
| 162 |
+
self.source = textwrap.dedent(fn)
|
| 163 |
+
if not cc:
|
| 164 |
+
cc = device_cc()
|
| 165 |
+
super().__init__(cc, **kwargs)
|
| 166 |
+
|
| 167 |
+
def parse(self, example_inputs) -> None:
|
| 168 |
+
self.example_inputs = example_inputs
|
| 169 |
+
self.ast = ast.parse(self.source)
|
| 170 |
+
self.visit(self.ast)
|
| 171 |
+
|
| 172 |
+
epilogue_functor = EpilogueFunctor(**kwargs)
|
| 173 |
+
epilogue_functor.trace(example_tensors)
|
| 174 |
+
return epilogue_functor
|
| 175 |
+
else:
|
| 176 |
+
raise NotImplementedError("Expect a callable Python function")
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Collection of builtin functions used for host reference in EVT
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
|
| 39 |
+
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
|
| 40 |
+
|
| 41 |
+
if is_torch_available():
|
| 42 |
+
import torch
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def multiply_add(x, y, z):
|
| 46 |
+
return x * y + z
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def sum(x, dim):
|
| 50 |
+
if is_numpy_tensor(x):
|
| 51 |
+
return x.sum(axis=tuple(dim))
|
| 52 |
+
elif is_torch_tensor(x):
|
| 53 |
+
return torch.sum(x, dim)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def max(x, dim):
|
| 57 |
+
if is_numpy_tensor(x):
|
| 58 |
+
return x.max(axis=tuple(dim))
|
| 59 |
+
elif is_torch_tensor(x):
|
| 60 |
+
return torch.amax(x, dim)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def maximum(x, y):
|
| 64 |
+
if is_numpy_tensor(x):
|
| 65 |
+
return np.maximum(x, y)
|
| 66 |
+
elif is_torch_tensor(x):
|
| 67 |
+
return torch.maximum(x, torch.tensor(y))
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def minimum(x, y):
|
| 71 |
+
if is_numpy_tensor(x):
|
| 72 |
+
return np.minimum(x, y)
|
| 73 |
+
elif is_torch_tensor(x):
|
| 74 |
+
return torch.minimum(x, torch.tensor(y))
|
| 75 |
+
|
| 76 |
+
def exp(x):
|
| 77 |
+
if is_numpy_tensor(x):
|
| 78 |
+
return np.exp(x)
|
| 79 |
+
elif is_torch_tensor(x):
|
| 80 |
+
return torch.exp(x)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
##############################################################################
|
| 84 |
+
# Layout manipulate nodes
|
| 85 |
+
##############################################################################
|
| 86 |
+
|
| 87 |
+
def permute(x, indices: tuple):
|
| 88 |
+
if is_numpy_tensor(x):
|
| 89 |
+
return np.transpose(x, axes=indices)
|
| 90 |
+
elif is_torch_tensor(x):
|
| 91 |
+
return x.permute(*indices)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def reshape(x, new_shape: tuple):
|
| 95 |
+
if is_numpy_tensor(x):
|
| 96 |
+
return np.reshape(x, newshape=new_shape)
|
| 97 |
+
elif is_torch_tensor(x):
|
| 98 |
+
return x.view(new_shape)
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Classes containing valid operations for a given compute capability and data types.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from itertools import combinations_with_replacement
|
| 38 |
+
import logging
|
| 39 |
+
|
| 40 |
+
import cutlass_library
|
| 41 |
+
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
|
| 42 |
+
|
| 43 |
+
import cutlass_cppgen
|
| 44 |
+
from cutlass_cppgen.utils.check import valid_stage_count
|
| 45 |
+
from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class KernelsForDataType:
|
| 52 |
+
"""
|
| 53 |
+
Container class for keeping track of kernels that correspond to a particular combination
|
| 54 |
+
of data types for operands A, B, and accumulator
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self, datatype_comb: tuple, layout_comb: tuple):
|
| 58 |
+
self.datatype_comb = datatype_comb
|
| 59 |
+
self.layout_comb = layout_comb
|
| 60 |
+
self.math_operations = set()
|
| 61 |
+
|
| 62 |
+
# Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
|
| 63 |
+
# constraint for the data type combination
|
| 64 |
+
self.kernels_by_alignment = {}
|
| 65 |
+
|
| 66 |
+
def add(self, operation):
|
| 67 |
+
"""
|
| 68 |
+
Add an operation to the list of supported kernels
|
| 69 |
+
"""
|
| 70 |
+
alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
|
| 71 |
+
if alignment_key not in self.kernels_by_alignment:
|
| 72 |
+
self.kernels_by_alignment[alignment_key] = []
|
| 73 |
+
self.kernels_by_alignment[alignment_key].append(operation)
|
| 74 |
+
self.math_operations.add(operation.tile_description.math_instruction.math_operation)
|
| 75 |
+
|
| 76 |
+
def alignments(self, operand: str):
|
| 77 |
+
"""
|
| 78 |
+
Returns an unsorted list of alignments supported by this data type combination
|
| 79 |
+
|
| 80 |
+
:param operand: identifier of operand in question (e.g., A, B, C)
|
| 81 |
+
:type operand: str
|
| 82 |
+
|
| 83 |
+
:return: unsorted list of alignments supported by this data type combination
|
| 84 |
+
:rtype: list
|
| 85 |
+
"""
|
| 86 |
+
operand_idx = self._operand_idx(operand)
|
| 87 |
+
return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()]
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def all_operations(self):
|
| 91 |
+
"""
|
| 92 |
+
Returns a list of all operations supported by this data type combination
|
| 93 |
+
|
| 94 |
+
:return: list of all operations supported by this data type combination
|
| 95 |
+
:rtype: list
|
| 96 |
+
"""
|
| 97 |
+
ops = []
|
| 98 |
+
for _, alignment_ops in self.kernels_by_alignment.items():
|
| 99 |
+
ops.extend(alignment_ops)
|
| 100 |
+
return ops
|
| 101 |
+
|
| 102 |
+
def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
|
| 103 |
+
key = sorted(list(self.kernels_by_alignment.keys()))[0]
|
| 104 |
+
kernels = self.kernels_by_alignment[key]
|
| 105 |
+
if math_operation is not None:
|
| 106 |
+
kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
|
| 107 |
+
return kernels[0]
|
| 108 |
+
|
| 109 |
+
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
|
| 110 |
+
"""
|
| 111 |
+
Returns operations satisfying the alignment constraints
|
| 112 |
+
|
| 113 |
+
:param alignment_A: alignment constraint of operations to return
|
| 114 |
+
:type alignment_A: int
|
| 115 |
+
:param alignment_B: alignment constraint of operations to return
|
| 116 |
+
:type alignment_B: int
|
| 117 |
+
:param alignment_C: alignment constraint of operations to return
|
| 118 |
+
:type alignment_C: int
|
| 119 |
+
:param math_operation: math operation to consider
|
| 120 |
+
:type math_operation: cutlass_cppgen.MathOperation
|
| 121 |
+
|
| 122 |
+
:return: list of operations
|
| 123 |
+
:rtype: list
|
| 124 |
+
"""
|
| 125 |
+
key = f"{alignment_A} {alignment_B} {alignment_C}"
|
| 126 |
+
|
| 127 |
+
if key not in self.kernels_by_alignment:
|
| 128 |
+
og_key = key
|
| 129 |
+
# Reconcile A, B, and C alignments by trying to align to the minimum
|
| 130 |
+
min_alignment = min(alignment_A, alignment_B, alignment_C)
|
| 131 |
+
key = f"{min_alignment} {min_alignment} {min_alignment}"
|
| 132 |
+
if key not in self.kernels_by_alignment:
|
| 133 |
+
# Finally, go through all available alignment combinations and find
|
| 134 |
+
# one for which all values are less than those passed in.
|
| 135 |
+
key = None
|
| 136 |
+
alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
| 137 |
+
for align_A, align_B, align_C in alignments:
|
| 138 |
+
if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
|
| 139 |
+
key = f"{align_A} {align_B} {align_C}"
|
| 140 |
+
break
|
| 141 |
+
|
| 142 |
+
if key is None:
|
| 143 |
+
raise Exception(
|
| 144 |
+
f"No operations of alignment {og_key} found for data type and layout "
|
| 145 |
+
f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
|
| 146 |
+
f"are {self.kernels_by_alignment.keys()}"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
ops = self.kernels_by_alignment[key]
|
| 150 |
+
if math_operation is not None:
|
| 151 |
+
ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
|
| 152 |
+
return ops
|
| 153 |
+
|
| 154 |
+
def _operand_idx(self, key: str) -> int:
|
| 155 |
+
operand_list = ["A", "B", "C"]
|
| 156 |
+
if key not in operand_list:
|
| 157 |
+
raise Exception(f"Unexpected operand {operand}")
|
| 158 |
+
|
| 159 |
+
return operand_list.index(key)
|
| 160 |
+
|
| 161 |
+
def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int:
|
| 162 |
+
"""
|
| 163 |
+
Returns the most preferable alignment for a given shape and layout
|
| 164 |
+
|
| 165 |
+
:param shape: extent of each dimension of the tensor
|
| 166 |
+
:type shape: tuple
|
| 167 |
+
:param layout: layout of the tensor
|
| 168 |
+
:type layout: cutlass_cppgen.LayoutType
|
| 169 |
+
:param operand: descriptor of the operand in question
|
| 170 |
+
:type operand: str
|
| 171 |
+
|
| 172 |
+
:return: maximum alignment supported by the data type combination and tensor size
|
| 173 |
+
:rtype: int
|
| 174 |
+
"""
|
| 175 |
+
operand_idx = self._operand_idx(operand)
|
| 176 |
+
|
| 177 |
+
# Determine the leading dimension of the shape
|
| 178 |
+
if layout == cutlass_cppgen.LayoutType.ColumnMajor:
|
| 179 |
+
ld = shape[-2]
|
| 180 |
+
elif layout == cutlass_cppgen.LayoutType.RowMajor:
|
| 181 |
+
ld = shape[-1]
|
| 182 |
+
elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
|
| 183 |
+
ld = shape[-1]
|
| 184 |
+
else:
|
| 185 |
+
raise Exception(f"Unexpected or unsupported layout {layout}")
|
| 186 |
+
|
| 187 |
+
for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True):
|
| 188 |
+
alignment = int(alignments.split(" ")[operand_idx])
|
| 189 |
+
if ld % alignment == 0:
|
| 190 |
+
return alignment
|
| 191 |
+
|
| 192 |
+
# Default to alignment of 1 if no others match
|
| 193 |
+
return 1
|
| 194 |
+
|
| 195 |
+
def sort(self):
|
| 196 |
+
"""
|
| 197 |
+
Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
|
| 198 |
+
"""
|
| 199 |
+
key = lambda op: (
|
| 200 |
+
op.tile_description.threadblock_shape[0]
|
| 201 |
+
* op.tile_description.threadblock_shape[1]
|
| 202 |
+
* op.tile_description.threadblock_shape[2]
|
| 203 |
+
)
|
| 204 |
+
for alignment in self.kernels_by_alignment.keys():
|
| 205 |
+
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
|
| 206 |
+
|
| 207 |
+
def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
|
| 208 |
+
"""
|
| 209 |
+
Returns whether `math_operation` is supported by at least one operation.
|
| 210 |
+
|
| 211 |
+
:param math_operation: math operation to consider
|
| 212 |
+
:type math_operation: cutlass_cppgen.MathOperation
|
| 213 |
+
|
| 214 |
+
:return: whether math_operation is supported by at least one operation
|
| 215 |
+
:rtype: bool
|
| 216 |
+
"""
|
| 217 |
+
return math_operation is None or math_operation in self.math_operations
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ArchOptions:
|
| 221 |
+
"""
|
| 222 |
+
Structure for keeping track of kernels available on a given compute capability
|
| 223 |
+
|
| 224 |
+
:param target_cc: compute capability of the device on which kernels will be run
|
| 225 |
+
:type target_cc: int
|
| 226 |
+
:param kernel_cc: compute capability of the kernels to generate
|
| 227 |
+
:type kernel_cc: int
|
| 228 |
+
:param operation_kind: type of operation to register
|
| 229 |
+
:type operation_kind: cutlass_library.OperationKind
|
| 230 |
+
:param gemm_kinds: types of GEMM operations that can be included
|
| 231 |
+
:type gemm_kinds: list
|
| 232 |
+
:param allowed_math_operations: types of primitive math operations allowed
|
| 233 |
+
:type allowed_math_operations: list
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
def __init__(
|
| 237 |
+
self,
|
| 238 |
+
target_cc: int,
|
| 239 |
+
kernel_cc: int,
|
| 240 |
+
operation_kind: cutlass_library.OperationKind,
|
| 241 |
+
gemm_kinds: list,
|
| 242 |
+
allowed_math_operations: list = [
|
| 243 |
+
cutlass_library.MathOperation.multiply_add,
|
| 244 |
+
cutlass_library.MathOperation.multiply_add_saturate,
|
| 245 |
+
cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
|
| 246 |
+
cutlass_library.MathOperation.multiply_add_fast_f32
|
| 247 |
+
]
|
| 248 |
+
):
|
| 249 |
+
self.cc = kernel_cc
|
| 250 |
+
|
| 251 |
+
# Dictionary with following structure:
|
| 252 |
+
# Key: OpcodeClass
|
| 253 |
+
# Value: Dictionary with the following structure:
|
| 254 |
+
# Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType),
|
| 255 |
+
# representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
|
| 256 |
+
# Value: KernelsForDataType
|
| 257 |
+
self.operations_by_opclass = {}
|
| 258 |
+
self.op_class = None
|
| 259 |
+
self.allowed_math_operations = allowed_math_operations
|
| 260 |
+
|
| 261 |
+
if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100:
|
| 262 |
+
return
|
| 263 |
+
|
| 264 |
+
# Identify the method within CUTLASS generator script that generates kernel
|
| 265 |
+
# descriptions for the target CC
|
| 266 |
+
generate_function_name = "GenerateSM" + str(kernel_cc)
|
| 267 |
+
if not hasattr(cutlass_library.generator, generate_function_name):
|
| 268 |
+
cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
|
| 269 |
+
return
|
| 270 |
+
generate_function = getattr(cutlass_library.generator, generate_function_name)
|
| 271 |
+
|
| 272 |
+
# Initialize a default manifest and populate it with valid kernel descriptions
|
| 273 |
+
# for the target CC
|
| 274 |
+
args = [
|
| 275 |
+
"--kernels=all",
|
| 276 |
+
f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
|
| 277 |
+
]
|
| 278 |
+
manifest_args = cutlass_library.generator.define_parser().parse_args(args)
|
| 279 |
+
manifest = cutlass_library.manifest.Manifest(manifest_args)
|
| 280 |
+
generate_function(manifest, cutlass_cppgen._nvcc_version)
|
| 281 |
+
|
| 282 |
+
if operation_kind not in manifest.operations:
|
| 283 |
+
# No kernels generated for this architecture, this could be because the CUDA
|
| 284 |
+
# toolkit is insufficient to support operations in this CC
|
| 285 |
+
cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
|
| 286 |
+
return
|
| 287 |
+
|
| 288 |
+
# Only one CC should be returned, given the setup above of calling only the generation scripts
|
| 289 |
+
# for a given CC
|
| 290 |
+
if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
|
| 291 |
+
raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
|
| 292 |
+
"is sufficient for the architecture in question.")
|
| 293 |
+
|
| 294 |
+
# Iterate through the available operations for this operation kind and
|
| 295 |
+
# find available opclasses and data types
|
| 296 |
+
for name, op_list in manifest.operations[operation_kind][kernel_cc].items():
|
| 297 |
+
for op in op_list:
|
| 298 |
+
|
| 299 |
+
if operation_kind == cutlass_library.OperationKind.Gemm:
|
| 300 |
+
if op.gemm_kind not in gemm_kinds:
|
| 301 |
+
continue
|
| 302 |
+
|
| 303 |
+
mi = op.tile_description.math_instruction
|
| 304 |
+
if mi.math_operation not in self.allowed_math_operations:
|
| 305 |
+
continue
|
| 306 |
+
|
| 307 |
+
# Prune operations that don't fit in shared memory
|
| 308 |
+
td = td_from_profiler_op(op)
|
| 309 |
+
if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
|
| 310 |
+
continue
|
| 311 |
+
|
| 312 |
+
if mi.opcode_class not in self.operations_by_opclass:
|
| 313 |
+
self.operations_by_opclass[mi.opcode_class] = {}
|
| 314 |
+
|
| 315 |
+
datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
|
| 316 |
+
layout_comb = (op.A.layout, op.B.layout)
|
| 317 |
+
|
| 318 |
+
# Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
|
| 319 |
+
if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
|
| 320 |
+
# TF32 kernels only supported on SM80 and beyond
|
| 321 |
+
if self.cc < 80:
|
| 322 |
+
continue
|
| 323 |
+
elif self.cc == 90 or self.cc == 100:
|
| 324 |
+
if (op.A.element != cutlass_library.DataType.f32
|
| 325 |
+
or op.B.element != cutlass_library.DataType.f32
|
| 326 |
+
or op.C.element != cutlass_library.DataType.f32):
|
| 327 |
+
continue
|
| 328 |
+
|
| 329 |
+
datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)
|
| 330 |
+
|
| 331 |
+
opclass_dict = self.operations_by_opclass[mi.opcode_class]
|
| 332 |
+
key = (datatype_comb, layout_comb)
|
| 333 |
+
if key not in opclass_dict:
|
| 334 |
+
opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
|
| 335 |
+
opclass_dict[key].add(op)
|
| 336 |
+
|
| 337 |
+
# Set the default opclass to TensorOp, if available. Otherwise default to SIMT
|
| 338 |
+
if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
|
| 339 |
+
self.op_class = cutlass_library.OpcodeClass.TensorOp
|
| 340 |
+
else:
|
| 341 |
+
self.op_class = cutlass_library.OpcodeClass.Simt
|
| 342 |
+
|
| 343 |
+
# The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
|
| 344 |
+
# Here, we generate additional versions via a generic TileDescription.
|
| 345 |
+
if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
|
| 346 |
+
self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}
|
| 347 |
+
|
| 348 |
+
if operation_kind == cutlass_library.OperationKind.Gemm:
|
| 349 |
+
types = [
|
| 350 |
+
(cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
|
| 351 |
+
(cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
|
| 352 |
+
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
|
| 353 |
+
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
|
| 354 |
+
(cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
|
| 355 |
+
(cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
|
| 356 |
+
]
|
| 357 |
+
|
| 358 |
+
# Add FP8 A/B/C
|
| 359 |
+
fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
|
| 360 |
+
for type_comb in combinations_with_replacement(fp8_types, 3):
|
| 361 |
+
types.append(type_comb)
|
| 362 |
+
|
| 363 |
+
# Add FP8 A/B with FP32 C
|
| 364 |
+
for type_comb in combinations_with_replacement(fp8_types, 2):
|
| 365 |
+
types.append(type_comb + (cutlass_cppgen.DataType.f32,))
|
| 366 |
+
|
| 367 |
+
layouts = [
|
| 368 |
+
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
|
| 369 |
+
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
|
| 370 |
+
(cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
|
| 371 |
+
(cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
|
| 372 |
+
]
|
| 373 |
+
elif operation_kind == cutlass_library.OperationKind.Conv2d:
|
| 374 |
+
types = [
|
| 375 |
+
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
|
| 376 |
+
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
|
| 377 |
+
(cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
|
| 378 |
+
(cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
|
| 379 |
+
]
|
| 380 |
+
|
| 381 |
+
layouts = [
|
| 382 |
+
(cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
|
| 383 |
+
]
|
| 384 |
+
else:
|
| 385 |
+
raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")
|
| 386 |
+
|
| 387 |
+
alignment = 1
|
| 388 |
+
epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
|
| 389 |
+
swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8
|
| 390 |
+
for type_comb in types:
|
| 391 |
+
for layout_comb in layouts:
|
| 392 |
+
comb = (type_comb, layout_comb)
|
| 393 |
+
if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
|
| 394 |
+
continue
|
| 395 |
+
|
| 396 |
+
A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
|
| 397 |
+
B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
|
| 398 |
+
C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
|
| 399 |
+
math_inst = cutlass_library.MathInstruction(
|
| 400 |
+
[1, 1, 1],
|
| 401 |
+
type_comb[0],
|
| 402 |
+
type_comb[1],
|
| 403 |
+
type_comb[2],
|
| 404 |
+
cutlass_library.OpcodeClass.Simt,
|
| 405 |
+
cutlass_library.MathOperation.multiply_add
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
td = cutlass_library.TileDescription(
|
| 409 |
+
[128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)
|
| 410 |
+
|
| 411 |
+
# Prune operations that don't fit in shared memory
|
| 412 |
+
if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
|
| 413 |
+
continue
|
| 414 |
+
|
| 415 |
+
new_kernels = KernelsForDataType(type_comb, layout_comb)
|
| 416 |
+
|
| 417 |
+
if operation_kind == cutlass_library.OperationKind.Gemm:
|
| 418 |
+
new_operation = cutlass_library.manifest.GemmOperation(
|
| 419 |
+
cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
|
| 420 |
+
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
|
| 421 |
+
new_kernels.add(new_operation)
|
| 422 |
+
elif operation_kind == cutlass_library.OperationKind.Conv2d:
|
| 423 |
+
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
| 424 |
+
new_operation = cutlass_library.manifest.Conv2dOperation(
|
| 425 |
+
conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
|
| 426 |
+
A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
|
| 427 |
+
group_mode=GroupMode.SingleGroup
|
| 428 |
+
)
|
| 429 |
+
new_kernels.add(new_operation)
|
| 430 |
+
|
| 431 |
+
self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels
|
| 432 |
+
|
| 433 |
+
# Sort all operations
|
| 434 |
+
for oc in self.operations_by_opclass.keys():
|
| 435 |
+
for comb in self.operations_by_opclass[oc].keys():
|
| 436 |
+
self.operations_by_opclass[oc][comb].sort()
|
| 437 |
+
|
| 438 |
+
def opclass_supports_combination(
|
| 439 |
+
self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
|
| 440 |
+
) -> bool:
|
| 441 |
+
"""
|
| 442 |
+
Returns whether the provided operation class supports the provided data type and layout combination
|
| 443 |
+
|
| 444 |
+
:param op_class: operation class to consider
|
| 445 |
+
:type op_class: cutlass_library.OpcodeClass
|
| 446 |
+
:param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
|
| 447 |
+
:type datatype_comb: tuple[cutlass_library.DataType]
|
| 448 |
+
:param layout_comb: tuple of data types for (layout_A, layout_B)
|
| 449 |
+
:type layout_comb: tuple[cutlass_library.LayoutType]
|
| 450 |
+
:param math_operation: math operation to consider or None if any can be considered
|
| 451 |
+
:type math_operation: cutlass_cppgen.MathOperation
|
| 452 |
+
|
| 453 |
+
:return: set of operation classes that support the provided data type and layout combination
|
| 454 |
+
:rtype: set
|
| 455 |
+
"""
|
| 456 |
+
if op_class not in self.operations_by_opclass:
|
| 457 |
+
raise Exception(f"Unexpected or unsupported operation class {op_class}")
|
| 458 |
+
|
| 459 |
+
if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
|
| 460 |
+
if math_operation is not None:
|
| 461 |
+
return operations.supports_math_operation(math_operation)
|
| 462 |
+
else:
|
| 463 |
+
return True
|
| 464 |
+
|
| 465 |
+
return False
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def supporting_opclasses(
|
| 469 |
+
self,
|
| 470 |
+
element_a: cutlass_library.DataType,
|
| 471 |
+
element_b: cutlass_library.DataType,
|
| 472 |
+
element_accumulator: cutlass_library.DataType,
|
| 473 |
+
layout_a: cutlass_library.LayoutType,
|
| 474 |
+
layout_b: cutlass_library.LayoutType,
|
| 475 |
+
math_operation: cutlass_library.MathOperation,
|
| 476 |
+
) -> set:
|
| 477 |
+
"""
|
| 478 |
+
Returns a set of operation classes that support the provided data type combination
|
| 479 |
+
|
| 480 |
+
:param element_a: data type of operand A
|
| 481 |
+
:type element_a: cutlass_library.DataType
|
| 482 |
+
:param element_b: data type of operand B
|
| 483 |
+
:type element_b: cutlass_library.DataType
|
| 484 |
+
:param element_accumulator: data type of accumulator
|
| 485 |
+
:type element_accumulator: cutlass_library.DataType
|
| 486 |
+
:param layout_a: layout of operand A
|
| 487 |
+
:type layout_a: cutlass_library.LayoutType
|
| 488 |
+
:param layout_b: layout of operand B
|
| 489 |
+
:type layout_b: cutlass_library.LayoutType
|
| 490 |
+
:param math_operation: math operation to consider
|
| 491 |
+
:type math_operation: cutlass_cppgen.MathOperation
|
| 492 |
+
|
| 493 |
+
:return: set of operation classes that support the provided data type combination
|
| 494 |
+
:rtype: set
|
| 495 |
+
"""
|
| 496 |
+
supporting_op_classes = set()
|
| 497 |
+
datatype_comb = (element_a, element_b, element_accumulator)
|
| 498 |
+
layout_comb = (layout_a, layout_b)
|
| 499 |
+
|
| 500 |
+
for op_class in self.operations_by_opclass.keys():
|
| 501 |
+
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
|
| 502 |
+
supporting_op_classes.add(op_class)
|
| 503 |
+
return supporting_op_classes
|
| 504 |
+
|
| 505 |
+
def operations(
|
| 506 |
+
self,
|
| 507 |
+
op_class: cutlass_library.OpcodeClass,
|
| 508 |
+
element_a: cutlass_library.DataType,
|
| 509 |
+
element_b: cutlass_library.DataType,
|
| 510 |
+
element_accumulator: cutlass_library.DataType,
|
| 511 |
+
layout_a: cutlass_library.LayoutType,
|
| 512 |
+
layout_b: cutlass_library.LayoutType,
|
| 513 |
+
math_operation: cutlass_library.MathOperation,
|
| 514 |
+
) -> KernelsForDataType:
|
| 515 |
+
"""
|
| 516 |
+
Returns whether the provided operation class supports the provided data type combination
|
| 517 |
+
|
| 518 |
+
:param op_class: operation class to consider
|
| 519 |
+
:type op_class: cutlass_library.OpcodeClass
|
| 520 |
+
:param element_a: data type of operand A
|
| 521 |
+
:type element_a: cutlass_library.DataType
|
| 522 |
+
:param element_b: data type of operand B
|
| 523 |
+
:type element_b: cutlass_library.DataType
|
| 524 |
+
:param element_accumulator: data type of accumulator
|
| 525 |
+
:type element_accumulator: cutlass_library.DataType
|
| 526 |
+
:param layout_a: layout of operand A
|
| 527 |
+
:type layout_a: cutlass_library.LayoutType
|
| 528 |
+
:param layout_b: layout of operand B
|
| 529 |
+
:type layout_b: cutlass_library.LayoutType
|
| 530 |
+
:param math_operation: math operation to consider
|
| 531 |
+
:type math_operation: cutlass_cppgen.MathOperation
|
| 532 |
+
|
| 533 |
+
:return: container of kernels by alignment supported by the provided combination of parameters
|
| 534 |
+
:rtype: KernelsForDataType
|
| 535 |
+
"""
|
| 536 |
+
datatype_comb = (element_a, element_b, element_accumulator)
|
| 537 |
+
layout_comb = (layout_a, layout_b)
|
| 538 |
+
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
|
| 539 |
+
raise Exception(
|
| 540 |
+
f"Data type layout combination {datatype_comb}, {layout_comb} "
|
| 541 |
+
f"is not supported by opcode class {op_class} on CC {self.cc}."
|
| 542 |
+
)
|
| 543 |
+
return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
class OptionRegistry:
|
| 547 |
+
"""
|
| 548 |
+
Container of all architecture-specific options
|
| 549 |
+
|
| 550 |
+
:param target_cc: compute capability of the device on which operations will be run
|
| 551 |
+
:type target_cc: int
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
def __init__(self, target_cc: int):
|
| 555 |
+
self.registry = {}
|
| 556 |
+
|
| 557 |
+
if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
|
| 558 |
+
raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")
|
| 559 |
+
|
| 560 |
+
gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
|
| 561 |
+
operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
|
| 562 |
+
# Construct options for each CC
|
| 563 |
+
for kernel_cc in _generator_ccs:
|
| 564 |
+
self.registry[kernel_cc] = {}
|
| 565 |
+
for opkind in operation_kinds:
|
| 566 |
+
self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)
|
| 567 |
+
|
| 568 |
+
def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
|
| 569 |
+
return self.registry.get(cc, None)[op_kind]
|
build/torch212-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
| 34 |
+
from cutlass_cppgen.op.gemm import Gemm
|
| 35 |
+
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
|
| 36 |
+
from cutlass_cppgen.op.op import OperationBase
|