Kernels:
Trusted publisher
Uploaded using `kernel-builder` (batch 12/32).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py +0 -336
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py +0 -294
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py +0 -306
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py +0 -277
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py +0 -137
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py +0 -42
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py +0 -143
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py +0 -120
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py +0 -169
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py +0 -64
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py +0 -90
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py +0 -217
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py +0 -164
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py +0 -53
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py +0 -97
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py +0 -59
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py +0 -319
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py +0 -46
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py +0 -109
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py +0 -2145
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py +0 -509
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py +0 -121
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py +0 -140
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py +0 -455
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py +0 -35
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py +0 -33
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py +0 -126
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py +0 -33
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py +0 -267
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py +0 -936
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py +0 -56
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py +0 -176
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py +0 -98
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py +0 -569
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py +0 -36
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py +0 -997
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py +0 -725
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py +0 -269
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py +0 -431
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py +0 -184
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py +0 -65
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py +0 -41
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py +0 -262
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py +0 -362
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py +0 -41
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py +0 -196
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py +0 -63
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py +0 -621
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py +0 -482
- build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py +0 -250
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/layout_nodes.py
DELETED
|
@@ -1,336 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Layout manipulation nodes and implementations
|
| 35 |
-
|
| 36 |
-
The layout Nodes change the layout of intermediate nodes in epilogue visitor graph
|
| 37 |
-
"""
|
| 38 |
-
|
| 39 |
-
from copy import deepcopy
|
| 40 |
-
|
| 41 |
-
from cutlass_library import LayoutType
|
| 42 |
-
from pycute import product, flatten
|
| 43 |
-
|
| 44 |
-
import cutlass_cppgen
|
| 45 |
-
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list
|
| 46 |
-
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
| 47 |
-
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class PermutationImpl:
|
| 51 |
-
"""
|
| 52 |
-
Detailed implementation and helper functions for permutation
|
| 53 |
-
"""
|
| 54 |
-
def __init__(self, node) -> None:
|
| 55 |
-
assert "indices" in node.kwargs.keys()
|
| 56 |
-
self.indices = list(node.kwargs["indices"])
|
| 57 |
-
self.inverse_indices = self.get_inverse_indices(self.indices)
|
| 58 |
-
|
| 59 |
-
def get_inverse_impl(self):
|
| 60 |
-
inverse_impl = deepcopy(self)
|
| 61 |
-
inverse_impl.indices = self.inverse_indices
|
| 62 |
-
inverse_impl.inverse_indices = self.indices
|
| 63 |
-
return inverse_impl
|
| 64 |
-
|
| 65 |
-
def update(self, shape):
|
| 66 |
-
num_dim = len(shape)
|
| 67 |
-
indices = self.indices
|
| 68 |
-
num_old_dim = len(indices)
|
| 69 |
-
# Add offset
|
| 70 |
-
for i, idx in enumerate(indices):
|
| 71 |
-
indices[i] = idx + num_dim - num_old_dim
|
| 72 |
-
# Add broadcast dims
|
| 73 |
-
for i in range(num_dim - num_old_dim):
|
| 74 |
-
indices = [i,] + indices
|
| 75 |
-
|
| 76 |
-
self.indices = indices
|
| 77 |
-
self.inverse_indices = self.get_inverse_indices(self.indices)
|
| 78 |
-
|
| 79 |
-
def get_inverse_indices(self, indices):
|
| 80 |
-
"""
|
| 81 |
-
Get the indices for inverse permutation
|
| 82 |
-
"""
|
| 83 |
-
num_dim = len(indices)
|
| 84 |
-
inverse_indices = [0] * num_dim
|
| 85 |
-
for i in range(num_dim):
|
| 86 |
-
inverse_indices[indices[i]] = i
|
| 87 |
-
return inverse_indices
|
| 88 |
-
|
| 89 |
-
def shape_propagation(self, input_node_meta):
|
| 90 |
-
input_shape = input_node_meta.tensor.shape
|
| 91 |
-
output_shape = tuple([input_shape[idx] for idx in self.indices])
|
| 92 |
-
return output_shape
|
| 93 |
-
|
| 94 |
-
def broadcast(self, shape, node_meta: NodeBase):
|
| 95 |
-
"""
|
| 96 |
-
Broadcast the inputs based on current shape
|
| 97 |
-
"""
|
| 98 |
-
self.update(shape)
|
| 99 |
-
inverse_shape = tuple([shape[idx] for idx in self.inverse_indices])
|
| 100 |
-
node_meta.tensor.broadcast(inverse_shape)
|
| 101 |
-
|
| 102 |
-
def apply_to_user(self, usr_meta: NodeBase):
|
| 103 |
-
"""
|
| 104 |
-
Propagate the permutation to the users of the current nodes
|
| 105 |
-
"""
|
| 106 |
-
usr_meta.tensor.permute(self.inverse_indices)
|
| 107 |
-
if hasattr(usr_meta, "store_tensor"):
|
| 108 |
-
if usr_meta.store_tensor is not None:
|
| 109 |
-
usr_meta.store_tensor.permute(self.inverse_indices)
|
| 110 |
-
|
| 111 |
-
def apply_to_input(self, input_meta: NodeBase):
|
| 112 |
-
"""
|
| 113 |
-
Propagate the permutation to inputs of the current nodes
|
| 114 |
-
"""
|
| 115 |
-
input_meta.tensor.permute(self.indices)
|
| 116 |
-
if hasattr(input_meta, "store_tensor"):
|
| 117 |
-
if input_meta.store_tensor is not None:
|
| 118 |
-
input_meta.store_tensor.permute(self.indices)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
class ReshapeImpl:
|
| 122 |
-
"""
|
| 123 |
-
Detailed implementation and helper functions for reshape
|
| 124 |
-
"""
|
| 125 |
-
def __init__(self, node) -> None:
|
| 126 |
-
self.node = node
|
| 127 |
-
assert "new_shape" in node.kwargs.keys()
|
| 128 |
-
self.output_shape = _list_to_tuple(node.kwargs["new_shape"])
|
| 129 |
-
|
| 130 |
-
def get_inverse_impl(self):
|
| 131 |
-
inverse_impl = deepcopy(self)
|
| 132 |
-
inverse_impl.output_shape = self.input_shape
|
| 133 |
-
inverse_impl.input_shape = self.output_shape
|
| 134 |
-
return inverse_impl
|
| 135 |
-
|
| 136 |
-
def shape_propagation(self, input_node_meta):
|
| 137 |
-
self.input_shape = input_node_meta.tensor.shape
|
| 138 |
-
return _list_to_tuple(self.output_shape)
|
| 139 |
-
|
| 140 |
-
def broadcast(self, shape, node_meta: NodeBase):
|
| 141 |
-
"""
|
| 142 |
-
Broadcast the inputs based on current shape.
|
| 143 |
-
"""
|
| 144 |
-
# Step 1: infer split
|
| 145 |
-
flatten_split_shape = self.infer_split(flatten(self.input_shape), flatten(self.output_shape))
|
| 146 |
-
split_input_shape = self.infer_merge(flatten_split_shape, self.input_shape)
|
| 147 |
-
split_output_shape = self.infer_merge(flatten_split_shape, self.output_shape)
|
| 148 |
-
|
| 149 |
-
# broadcast shape -> split_output_shape -> flatten_split_shape
|
| 150 |
-
if len(shape) - len(split_output_shape) > 0:
|
| 151 |
-
for _ in range(len(shape) - len(split_output_shape)):
|
| 152 |
-
split_output_shape = [1,] + split_output_shape
|
| 153 |
-
flatten_split_shape = [1,] + flatten_split_shape
|
| 154 |
-
split_input_shape = [1,] + split_input_shape
|
| 155 |
-
broadcast_factor = []
|
| 156 |
-
for dim, old_dim in zip(shape, split_output_shape):
|
| 157 |
-
if not isinstance(dim, list):
|
| 158 |
-
dim = [dim,]
|
| 159 |
-
if not isinstance(old_dim, list):
|
| 160 |
-
old_dim = [old_dim,]
|
| 161 |
-
if product(tuple(dim)) == product(tuple(old_dim)):
|
| 162 |
-
broadcast_factor += [1] * len(old_dim)
|
| 163 |
-
elif product(tuple(old_dim)) == 1:
|
| 164 |
-
assert len(dim) == 1
|
| 165 |
-
broadcast_factor.append(dim[0])
|
| 166 |
-
else:
|
| 167 |
-
raise NotImplementedError(f"Invalid Broadcast: {old_dim} -> {dim}")
|
| 168 |
-
|
| 169 |
-
# flatten_split_shape -> split_input_shape
|
| 170 |
-
factor_idx = 0
|
| 171 |
-
broadcast_split_input_shape = []
|
| 172 |
-
for dim in split_input_shape:
|
| 173 |
-
if isinstance(dim, list):
|
| 174 |
-
new_dim = []
|
| 175 |
-
for d in dim:
|
| 176 |
-
new_dim.append(d * broadcast_factor[factor_idx])
|
| 177 |
-
factor_idx += 1
|
| 178 |
-
broadcast_split_input_shape.append(new_dim)
|
| 179 |
-
else:
|
| 180 |
-
broadcast_split_input_shape.append(dim * broadcast_factor[factor_idx])
|
| 181 |
-
factor_idx += 1
|
| 182 |
-
broadcast_split_input_shape = _list_to_tuple(broadcast_split_input_shape)
|
| 183 |
-
node_meta.tensor.reshape(_list_to_tuple(split_input_shape))
|
| 184 |
-
node_meta.tensor.broadcast(broadcast_split_input_shape)
|
| 185 |
-
# Last reshape op to clean up
|
| 186 |
-
broadcast_input_shape = tuple([product(dim) for dim in broadcast_split_input_shape])
|
| 187 |
-
node_meta.tensor.reshape(broadcast_input_shape)
|
| 188 |
-
# Update the input shape and output shape
|
| 189 |
-
self.input_shape = _list_to_tuple(node_meta.tensor.shape)
|
| 190 |
-
self.output_shape = _list_to_tuple(shape)
|
| 191 |
-
|
| 192 |
-
def apply_to_user(self, user_meta: NodeBase):
|
| 193 |
-
"""
|
| 194 |
-
Propagate the reshape to user nodes
|
| 195 |
-
"""
|
| 196 |
-
user_meta.tensor.reshape(tuple(self.input_shape))
|
| 197 |
-
if hasattr(user_meta, "store_tensor"):
|
| 198 |
-
if user_meta.store_tensor is not None:
|
| 199 |
-
user_meta.store_tensor.reshape(tuple(self.input_shape))
|
| 200 |
-
|
| 201 |
-
def apply_to_input(self, input_meta: NodeBase):
|
| 202 |
-
"""
|
| 203 |
-
Propagate the reshape to input nodes
|
| 204 |
-
"""
|
| 205 |
-
input_meta.tensor.reshape(tuple(self.output_shape))
|
| 206 |
-
if hasattr(input_meta, "store_tensor"):
|
| 207 |
-
if input_meta.store_tensor is not None:
|
| 208 |
-
input_meta.store_tensor.reshape(tuple(self.output_shape))
|
| 209 |
-
|
| 210 |
-
#
|
| 211 |
-
# Helper functions
|
| 212 |
-
#
|
| 213 |
-
|
| 214 |
-
def infer_split(self, input_shape, output_shape):
|
| 215 |
-
"""
|
| 216 |
-
Infer the flatten splitted shape that can be merged to both input_shape and output_shape
|
| 217 |
-
"""
|
| 218 |
-
input_shape = _tuple_to_list(input_shape)
|
| 219 |
-
output_shape = _tuple_to_list(output_shape)
|
| 220 |
-
if len(input_shape) == 0 and len(output_shape) == 0:
|
| 221 |
-
return []
|
| 222 |
-
if len(input_shape) == 0:
|
| 223 |
-
if product(tuple(output_shape)) != 1:
|
| 224 |
-
raise ValueError("Invalid reshape size")
|
| 225 |
-
else:
|
| 226 |
-
return output_shape
|
| 227 |
-
if len(output_shape) == 0:
|
| 228 |
-
if product(tuple(input_shape)) != 1:
|
| 229 |
-
raise ValueError("Invalid reshape size")
|
| 230 |
-
else:
|
| 231 |
-
return input_shape
|
| 232 |
-
# This is done recursively by only process the last dimension at each time
|
| 233 |
-
old_dim = input_shape[-1]
|
| 234 |
-
new_dim = output_shape[-1]
|
| 235 |
-
# Exact match
|
| 236 |
-
if old_dim == new_dim:
|
| 237 |
-
return self.infer_split(input_shape[:-1], output_shape[:-1]) + [new_dim,]
|
| 238 |
-
# Needs split
|
| 239 |
-
if old_dim > new_dim and old_dim % new_dim == 0:
|
| 240 |
-
residual = old_dim // new_dim
|
| 241 |
-
return self.infer_split(input_shape[:-1] + [residual,], output_shape[:-1]) + [new_dim,]
|
| 242 |
-
# Needs merge
|
| 243 |
-
if old_dim < new_dim and new_dim % old_dim == 0:
|
| 244 |
-
residual = new_dim // old_dim
|
| 245 |
-
return self.infer_split(input_shape[:-1], output_shape[:-1] + [residual,]) + [old_dim,]
|
| 246 |
-
|
| 247 |
-
raise NotImplementedError(f"Unsupported split: {input_shape} -> {output_shape}")
|
| 248 |
-
|
| 249 |
-
def infer_merge(self, flatten_shape, shape):
|
| 250 |
-
flatten_shape = _tuple_to_list(flatten_shape)
|
| 251 |
-
shape = _tuple_to_list(shape)
|
| 252 |
-
idx_flat = len(flatten_shape) - 1
|
| 253 |
-
merged_shape = []
|
| 254 |
-
for dim in reversed(shape):
|
| 255 |
-
# Exact match
|
| 256 |
-
if dim == flatten_shape[idx_flat]:
|
| 257 |
-
merged_shape.append(dim)
|
| 258 |
-
idx_flat -= 1
|
| 259 |
-
# need group
|
| 260 |
-
elif dim > flatten_shape[idx_flat] and dim % flatten_shape[idx_flat] == 0:
|
| 261 |
-
residual = dim
|
| 262 |
-
group = []
|
| 263 |
-
while(residual > 1):
|
| 264 |
-
group.append(flatten_shape[idx_flat])
|
| 265 |
-
residual = residual // flatten_shape[idx_flat]
|
| 266 |
-
idx_flat -= 1
|
| 267 |
-
merged_shape.append(group[::-1])
|
| 268 |
-
else:
|
| 269 |
-
raise NotImplementedError(f"Unsupported merge: {flatten_shape} -> {shape}")
|
| 270 |
-
|
| 271 |
-
return merged_shape[::-1]
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
class LayoutNode(NodeBase):
|
| 275 |
-
"""
|
| 276 |
-
Layout manipulation nodes
|
| 277 |
-
"""
|
| 278 |
-
fn_to_impl = {
|
| 279 |
-
"permute": PermutationImpl,
|
| 280 |
-
"reshape": ReshapeImpl
|
| 281 |
-
}
|
| 282 |
-
def __init__(self, name: str, fn, kwargs: dict) -> None:
|
| 283 |
-
super().__init__(name)
|
| 284 |
-
self.op = "layout"
|
| 285 |
-
self.fn = fn
|
| 286 |
-
self.kwargs = kwargs
|
| 287 |
-
self.underlying_impl = self.fn_to_impl[self.fn.__name__](self)
|
| 288 |
-
|
| 289 |
-
def get_inverse_node(self):
|
| 290 |
-
inverse_node = deepcopy(self)
|
| 291 |
-
inverse_node.underlying_impl = self.underlying_impl.get_inverse_impl()
|
| 292 |
-
return inverse_node
|
| 293 |
-
|
| 294 |
-
def shape_propagation(self, input_node_metas):
|
| 295 |
-
if self._tensor is not None:
|
| 296 |
-
return
|
| 297 |
-
assert len(input_node_metas) == 1, "Layout node can only have one input node"
|
| 298 |
-
|
| 299 |
-
output_shape = self.underlying_impl.shape_propagation(input_node_metas[0])
|
| 300 |
-
|
| 301 |
-
self._tensor = Tensor(
|
| 302 |
-
element=self.element_output,
|
| 303 |
-
shape=output_shape, layout_tag=LayoutType.RowMajor
|
| 304 |
-
)
|
| 305 |
-
|
| 306 |
-
return super().shape_propagation(input_node_metas)
|
| 307 |
-
|
| 308 |
-
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 309 |
-
"""
|
| 310 |
-
The store nodes has element_output = element_input
|
| 311 |
-
"""
|
| 312 |
-
assert len(input_node_metas) == 1, "Layout node can only have one input node"
|
| 313 |
-
self.element_output = input_node_metas[0].element_output
|
| 314 |
-
|
| 315 |
-
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 316 |
-
"""
|
| 317 |
-
Propagate the broadcast in the reversed topological order
|
| 318 |
-
"""
|
| 319 |
-
if self.tensor is None:
|
| 320 |
-
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
| 321 |
-
shape = self.tensor.shape
|
| 322 |
-
|
| 323 |
-
for child in input_node_metas:
|
| 324 |
-
self.underlying_impl.broadcast(shape, child)
|
| 325 |
-
|
| 326 |
-
def apply_to_user(self, usr_meta: NodeBase):
|
| 327 |
-
"""
|
| 328 |
-
Propagate the permutation to user nodes
|
| 329 |
-
"""
|
| 330 |
-
self.underlying_impl.apply_to_user(usr_meta)
|
| 331 |
-
|
| 332 |
-
def apply_to_input(self, input_meta: NodeBase):
|
| 333 |
-
"""
|
| 334 |
-
Propagate the permutation to input nodes
|
| 335 |
-
"""
|
| 336 |
-
self.underlying_impl.apply_to_input(input_meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/load_nodes.py
DELETED
|
@@ -1,294 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Load nodes and implementations
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import ctypes
|
| 38 |
-
|
| 39 |
-
from cutlass_cppgen.backend.c_types import tuple_factory
|
| 40 |
-
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
| 41 |
-
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class LoadImplBase(ImplBase):
|
| 45 |
-
"""
|
| 46 |
-
Base class for load node implementations
|
| 47 |
-
"""
|
| 48 |
-
reserved_names = ["accum", "C"]
|
| 49 |
-
def __init__(self, node) -> None:
|
| 50 |
-
super().__init__(node)
|
| 51 |
-
self.element = node.element
|
| 52 |
-
self.element_output = node.element_output
|
| 53 |
-
self.stride = node.tensor.stride
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
class AccumulatorImpl(LoadImplBase):
|
| 57 |
-
"""
|
| 58 |
-
Accumulator node implementation
|
| 59 |
-
"""
|
| 60 |
-
|
| 61 |
-
@staticmethod
|
| 62 |
-
def match(node, problem_size: tuple):
|
| 63 |
-
return node.name == "accum" and node.tensor.shape == problem_size
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
class LoadSrcImpl(LoadImplBase):
|
| 67 |
-
"""
|
| 68 |
-
Load C implementation
|
| 69 |
-
"""
|
| 70 |
-
@property
|
| 71 |
-
def name_camel(self) -> str:
|
| 72 |
-
return "TensorC"
|
| 73 |
-
|
| 74 |
-
@property
|
| 75 |
-
def argument_type_c(self):
|
| 76 |
-
stride_mnl = self.get_stride_mnl()
|
| 77 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 78 |
-
class _Argument(ctypes.Structure):
|
| 79 |
-
_fields_ = [
|
| 80 |
-
("ptr_C", ctypes.c_void_p),
|
| 81 |
-
("stride_C", tuple_type)
|
| 82 |
-
]
|
| 83 |
-
def __init__(self, ptr) -> None:
|
| 84 |
-
self.ptr_C = ptr
|
| 85 |
-
self.stride_C = tuple_type(stride_mnl)
|
| 86 |
-
|
| 87 |
-
return _Argument
|
| 88 |
-
|
| 89 |
-
@staticmethod
|
| 90 |
-
def match(node, problem_size: tuple):
|
| 91 |
-
return node.name == "C" and node.tensor.shape == problem_size
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
class AuxLoadImpl(LoadImplBase):
|
| 95 |
-
"""
|
| 96 |
-
Load arbitrary tensor
|
| 97 |
-
"""
|
| 98 |
-
@property
|
| 99 |
-
def argument_type(self):
|
| 100 |
-
stride_mnl = self.get_stride_mnl()
|
| 101 |
-
name = self.name
|
| 102 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 103 |
-
element_type = self.element
|
| 104 |
-
class _Argument(ctypes.Structure):
|
| 105 |
-
_fields_ = [
|
| 106 |
-
("ptr_aux", ctypes.c_void_p),
|
| 107 |
-
("null_default", dtype2ctype[element_type]),
|
| 108 |
-
("dAux", tuple_type)
|
| 109 |
-
]
|
| 110 |
-
def __init__(self, kwargs) -> None:
|
| 111 |
-
ptr = kwargs[name]
|
| 112 |
-
self.ptr_aux = ptr
|
| 113 |
-
self.null_default = to_ctype_value(0, element_type)
|
| 114 |
-
self.dAux = tuple_type(stride_mnl)
|
| 115 |
-
|
| 116 |
-
return _Argument
|
| 117 |
-
|
| 118 |
-
@staticmethod
|
| 119 |
-
def match(node, problem_size: tuple):
|
| 120 |
-
if node.name in LoadImplBase.reserved_names:
|
| 121 |
-
return False
|
| 122 |
-
strideMN = node.tensor.stride[-2:]
|
| 123 |
-
if (strideMN[0] == 1 and strideMN[1] != 0 or
|
| 124 |
-
strideMN[0] != 0 and strideMN[1] == 1 ):
|
| 125 |
-
return True
|
| 126 |
-
else:
|
| 127 |
-
return False
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
class RowBroadcastImpl(LoadImplBase):
|
| 131 |
-
"""
|
| 132 |
-
Broadcast a row vector
|
| 133 |
-
"""
|
| 134 |
-
def __init__(self, node) -> None:
|
| 135 |
-
super().__init__(node)
|
| 136 |
-
self.stride_dtype = "int"
|
| 137 |
-
|
| 138 |
-
@property
|
| 139 |
-
def argument_type(self):
|
| 140 |
-
stride_mnl = self.get_stride_mnl()
|
| 141 |
-
name = self.name
|
| 142 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 143 |
-
element_type = self.element
|
| 144 |
-
class _Argument(ctypes.Structure):
|
| 145 |
-
_fields_ = [
|
| 146 |
-
("ptr_row", ctypes.c_void_p),
|
| 147 |
-
("null_default", dtype2ctype[element_type]),
|
| 148 |
-
("dRow", tuple_type)
|
| 149 |
-
]
|
| 150 |
-
def __init__(self, kwargs) -> None:
|
| 151 |
-
ptr = kwargs[name]
|
| 152 |
-
self.ptr_row = ptr
|
| 153 |
-
self.null_default = to_ctype_value(0, element_type)
|
| 154 |
-
self.dRow = tuple_type(stride_mnl)
|
| 155 |
-
|
| 156 |
-
return _Argument
|
| 157 |
-
|
| 158 |
-
@staticmethod
|
| 159 |
-
def match(node, problem_size: tuple):
|
| 160 |
-
if node.name in LoadImplBase.reserved_names:
|
| 161 |
-
return False
|
| 162 |
-
|
| 163 |
-
strideMN = node.tensor.stride[-2:]
|
| 164 |
-
if strideMN == (0, 1):
|
| 165 |
-
return True
|
| 166 |
-
else:
|
| 167 |
-
return False
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
class ColumnBroadcastImpl(LoadImplBase):
|
| 171 |
-
"""
|
| 172 |
-
Broadcast a column vector
|
| 173 |
-
"""
|
| 174 |
-
def __init__(self, node) -> None:
|
| 175 |
-
super().__init__(node)
|
| 176 |
-
self.stride_dtype = "int"
|
| 177 |
-
|
| 178 |
-
@property
|
| 179 |
-
def argument_type(self):
|
| 180 |
-
stride_mnl = self.get_stride_mnl()
|
| 181 |
-
name = self.name
|
| 182 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 183 |
-
element_type = self.element
|
| 184 |
-
class _Argument(ctypes.Structure):
|
| 185 |
-
_fields_ = [
|
| 186 |
-
("ptr_col", ctypes.c_void_p),
|
| 187 |
-
("null_default", dtype2ctype[element_type]),
|
| 188 |
-
("dCol", tuple_type)
|
| 189 |
-
]
|
| 190 |
-
def __init__(self, kwargs) -> None:
|
| 191 |
-
ptr = kwargs[name]
|
| 192 |
-
self.ptr_col = int(ptr)
|
| 193 |
-
self.null_default = to_ctype_value(0, element_type)
|
| 194 |
-
self.dCol = tuple_type(stride_mnl)
|
| 195 |
-
|
| 196 |
-
return _Argument
|
| 197 |
-
|
| 198 |
-
@staticmethod
|
| 199 |
-
def match(node, problem_size: tuple):
|
| 200 |
-
if node.name in LoadImplBase.reserved_names:
|
| 201 |
-
return False
|
| 202 |
-
|
| 203 |
-
strideMN = node.tensor.stride[-2:]
|
| 204 |
-
if strideMN == (1, 0):
|
| 205 |
-
return True
|
| 206 |
-
else:
|
| 207 |
-
return False
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
class ScalarBroadcastImpl(LoadImplBase):
|
| 211 |
-
"""
|
| 212 |
-
Broadcast a scalar
|
| 213 |
-
"""
|
| 214 |
-
def __init__(self, node) -> None:
|
| 215 |
-
super().__init__(node)
|
| 216 |
-
self.stride_dtype = "int"
|
| 217 |
-
|
| 218 |
-
@property
|
| 219 |
-
def argument_type(self):
|
| 220 |
-
stride_mnl = self.get_stride_mnl()
|
| 221 |
-
name = self.name
|
| 222 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 223 |
-
element_type = self.element
|
| 224 |
-
|
| 225 |
-
if self.tensor.is_constant:
|
| 226 |
-
value = self.tensor.value
|
| 227 |
-
class _Argument(ctypes.Structure):
|
| 228 |
-
_fields_ = [
|
| 229 |
-
("scalars", dtype2ctype[element_type]),
|
| 230 |
-
("scalar_ptrs", ctypes.c_void_p),
|
| 231 |
-
("dScalar", tuple_type)
|
| 232 |
-
]
|
| 233 |
-
def __init__(self, kwargs) -> None:
|
| 234 |
-
self.scalars = to_ctype_value(value, element_type)
|
| 235 |
-
self.scalar_ptrs = 0
|
| 236 |
-
self.dScalar = tuple_type(stride_mnl)
|
| 237 |
-
|
| 238 |
-
else:
|
| 239 |
-
class _Argument(ctypes.Structure):
|
| 240 |
-
_fields_ = [
|
| 241 |
-
("scalars", dtype2ctype[element_type]),
|
| 242 |
-
("scalar_ptrs", ctypes.c_void_p),
|
| 243 |
-
("dScalar", tuple_type)
|
| 244 |
-
]
|
| 245 |
-
def __init__(self, kwargs) -> None:
|
| 246 |
-
scalar_or_ptr = kwargs[name]
|
| 247 |
-
if isinstance(scalar_or_ptr, float):
|
| 248 |
-
self.scalars = to_ctype_value(scalar_or_ptr, element_type)
|
| 249 |
-
self.scalar_ptrs = 0
|
| 250 |
-
else:
|
| 251 |
-
self.scalar_ptrs = int(scalar_or_ptr)
|
| 252 |
-
|
| 253 |
-
self.dScalar = tuple_type(stride_mnl)
|
| 254 |
-
|
| 255 |
-
return _Argument
|
| 256 |
-
|
| 257 |
-
@staticmethod
|
| 258 |
-
def match(node, problem_size: tuple):
|
| 259 |
-
if node.name in LoadImplBase.reserved_names:
|
| 260 |
-
return False
|
| 261 |
-
|
| 262 |
-
strideMN = node.tensor.stride[-2:]
|
| 263 |
-
if strideMN == (0, 0):
|
| 264 |
-
return True
|
| 265 |
-
else:
|
| 266 |
-
return False
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
class LoadNode(NodeBase):
|
| 270 |
-
"""
|
| 271 |
-
Load Node
|
| 272 |
-
"""
|
| 273 |
-
cnt = 0
|
| 274 |
-
possible_impls = [
|
| 275 |
-
AccumulatorImpl, LoadSrcImpl, AuxLoadImpl,
|
| 276 |
-
RowBroadcastImpl, ColumnBroadcastImpl,
|
| 277 |
-
ScalarBroadcastImpl
|
| 278 |
-
]
|
| 279 |
-
def __init__(self, name: str) -> None:
|
| 280 |
-
if name is None:
|
| 281 |
-
name = f"load{LoadNode.cnt}"
|
| 282 |
-
LoadNode.cnt += 1
|
| 283 |
-
super().__init__(name)
|
| 284 |
-
self.op = "load"
|
| 285 |
-
|
| 286 |
-
def type_propagation(self, *args, **kwargs):
|
| 287 |
-
"""
|
| 288 |
-
Load node loads tensor under type `tensor.element` and returns an array of type `tensor.element`.
|
| 289 |
-
"""
|
| 290 |
-
if self.tensor is None:
|
| 291 |
-
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
| 292 |
-
|
| 293 |
-
self.element = self.tensor.element
|
| 294 |
-
self.element_output = self.tensor.element
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/node.py
DELETED
|
@@ -1,306 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Base & visitor classes of DAGIR Nodes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import ctypes
|
| 38 |
-
from re import sub
|
| 39 |
-
|
| 40 |
-
from cutlass_library import LayoutType
|
| 41 |
-
|
| 42 |
-
from cutlass_cppgen.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple
|
| 43 |
-
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class TupleEmitter:
|
| 47 |
-
"""
|
| 48 |
-
Emit the cute tuple to C++ code
|
| 49 |
-
"""
|
| 50 |
-
def __init__(self, stride_dtype):
|
| 51 |
-
self.stride_dtype = stride_dtype
|
| 52 |
-
|
| 53 |
-
def emit(self, py_tuple):
|
| 54 |
-
if isinstance(py_tuple, int):
|
| 55 |
-
if py_tuple in [0, 1]:
|
| 56 |
-
return f"cute::Int<{py_tuple}>"
|
| 57 |
-
else:
|
| 58 |
-
return f"{self.stride_dtype}"
|
| 59 |
-
elif isinstance(py_tuple, tuple):
|
| 60 |
-
decl = "cute::Stride<"
|
| 61 |
-
for item in py_tuple:
|
| 62 |
-
decl += self.emit(item) + ", "
|
| 63 |
-
return decl[:-2] + ">"
|
| 64 |
-
else:
|
| 65 |
-
raise ValueError(f"TupleEmitter.emit only accepts tuple or int, got {type(py_tuple).__name__}")
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
class ImplBase:
|
| 69 |
-
"""
|
| 70 |
-
Base class for Node Implementation
|
| 71 |
-
"""
|
| 72 |
-
def __init__(self, node) -> None:
|
| 73 |
-
self.node = node
|
| 74 |
-
self.name = node.name
|
| 75 |
-
self.tensor = node.tensor
|
| 76 |
-
self._type_decl = None
|
| 77 |
-
self.tuple_emitter = TupleEmitter("int64_t")
|
| 78 |
-
|
| 79 |
-
@property
|
| 80 |
-
def stride_dtype(self):
|
| 81 |
-
return self.tuple_emitter.stride_dtype
|
| 82 |
-
|
| 83 |
-
@stride_dtype.setter
|
| 84 |
-
def stride_dtype(self, stride_dtype):
|
| 85 |
-
self.tuple_emitter.stride_dtype = stride_dtype
|
| 86 |
-
|
| 87 |
-
@staticmethod
|
| 88 |
-
def match(node, problem_size: tuple):
|
| 89 |
-
"""
|
| 90 |
-
Match function used in get_underlying_impl
|
| 91 |
-
"""
|
| 92 |
-
raise NotImplementedError(f"The `match` function is not defined.")
|
| 93 |
-
|
| 94 |
-
@property
|
| 95 |
-
def argument_type(self):
|
| 96 |
-
"""
|
| 97 |
-
Default class for Argument Type
|
| 98 |
-
"""
|
| 99 |
-
class _Argument(ctypes.Structure):
|
| 100 |
-
_fields_ = []
|
| 101 |
-
|
| 102 |
-
def __init__(self, *args, **kwargs) -> None:
|
| 103 |
-
pass
|
| 104 |
-
|
| 105 |
-
return _Argument
|
| 106 |
-
|
| 107 |
-
@property
|
| 108 |
-
def name_camel(self) -> str:
|
| 109 |
-
"""
|
| 110 |
-
Return the CamelCase name.
|
| 111 |
-
"""
|
| 112 |
-
return sub(r"(_|-)+", " ", self.name).title().replace(" ", "")
|
| 113 |
-
|
| 114 |
-
@property
|
| 115 |
-
def stride_mnl(self):
|
| 116 |
-
"""
|
| 117 |
-
Typename StrideMNL
|
| 118 |
-
"""
|
| 119 |
-
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
| 120 |
-
return self.tuple_emitter.emit(stride)
|
| 121 |
-
|
| 122 |
-
def get_non_constant_stride(self, py_tuple):
|
| 123 |
-
if isinstance(py_tuple, int):
|
| 124 |
-
if py_tuple not in [0, 1]:
|
| 125 |
-
return py_tuple
|
| 126 |
-
else:
|
| 127 |
-
return None
|
| 128 |
-
non_constant_stride = []
|
| 129 |
-
for item in py_tuple:
|
| 130 |
-
item_out = self.get_non_constant_stride(item)
|
| 131 |
-
if item_out:
|
| 132 |
-
non_constant_stride.append(item_out)
|
| 133 |
-
return tuple(non_constant_stride)
|
| 134 |
-
|
| 135 |
-
def get_stride_mnl(self):
|
| 136 |
-
"""
|
| 137 |
-
Get the non-zero stride mnl. This is used in argument construction
|
| 138 |
-
"""
|
| 139 |
-
stride = _list_to_tuple([self.stride[-2], self.stride[-1]] + list(_reverse_tuple(tuple(self.stride[:-2]))))
|
| 140 |
-
return stride
|
| 141 |
-
|
| 142 |
-
def get_smem_size(self, *args, **kwargs):
|
| 143 |
-
"""
|
| 144 |
-
Get the shared memory size and alignment of current node
|
| 145 |
-
"""
|
| 146 |
-
return (0, 1)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
class NoOpImpl(ImplBase):
|
| 150 |
-
"""
|
| 151 |
-
The NoOpImpl does nothing but forward its input to users
|
| 152 |
-
"""
|
| 153 |
-
def __init__(self, node) -> None:
|
| 154 |
-
super().__init__(node)
|
| 155 |
-
|
| 156 |
-
@staticmethod
|
| 157 |
-
def match(node, problem_size: tuple):
|
| 158 |
-
if node.op == "store":
|
| 159 |
-
# Store that is not output is a No OP
|
| 160 |
-
return not node.is_output
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
class NodeBase:
|
| 164 |
-
"""
|
| 165 |
-
Base class of DAG Node
|
| 166 |
-
"""
|
| 167 |
-
def __init__(self, name: str) -> None:
|
| 168 |
-
self.name = name
|
| 169 |
-
self.underlying_impl = None
|
| 170 |
-
|
| 171 |
-
self._tensor = None
|
| 172 |
-
|
| 173 |
-
# Whether the node is disabled for emit
|
| 174 |
-
self.disabled = False
|
| 175 |
-
|
| 176 |
-
@property
|
| 177 |
-
def name_camel(self) -> str:
|
| 178 |
-
"""
|
| 179 |
-
Return the CamelCase name.
|
| 180 |
-
"""
|
| 181 |
-
return self.underlying_impl.name_camel
|
| 182 |
-
|
| 183 |
-
@property
|
| 184 |
-
def tensor(self) -> Tensor:
|
| 185 |
-
"""
|
| 186 |
-
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
| 187 |
-
"""
|
| 188 |
-
return self._tensor
|
| 189 |
-
|
| 190 |
-
@tensor.setter
|
| 191 |
-
def tensor(self, kwargs):
|
| 192 |
-
"""
|
| 193 |
-
Setting the tensor
|
| 194 |
-
"""
|
| 195 |
-
self._tensor = Tensor(**kwargs)
|
| 196 |
-
|
| 197 |
-
#
|
| 198 |
-
# Helper functions for type/shape propagation
|
| 199 |
-
#
|
| 200 |
-
|
| 201 |
-
def shape_propagation(self, input_node_metas):
|
| 202 |
-
"""
|
| 203 |
-
Infer shape from input nodes
|
| 204 |
-
General Broadcasting Rules from NumPy
|
| 205 |
-
When operating on two arrays, we compare their shapes element-wise.
|
| 206 |
-
It starts with the trailing (i.e. rightmost) dimension and works its
|
| 207 |
-
way left. Two dimensions are compatible when
|
| 208 |
-
1. they are equal
|
| 209 |
-
2. one of them is 1
|
| 210 |
-
"""
|
| 211 |
-
if self._tensor is not None:
|
| 212 |
-
return
|
| 213 |
-
|
| 214 |
-
shape = None
|
| 215 |
-
for src in input_node_metas:
|
| 216 |
-
src_shape = src.tensor.shape
|
| 217 |
-
if shape is None:
|
| 218 |
-
shape = src_shape
|
| 219 |
-
else:
|
| 220 |
-
len_difference = len(shape) - len(src_shape)
|
| 221 |
-
if len_difference > 0:
|
| 222 |
-
for _ in range(len_difference):
|
| 223 |
-
src_shape = [1, ] + list(src_shape)
|
| 224 |
-
elif len_difference < 0:
|
| 225 |
-
for _ in range(-len_difference):
|
| 226 |
-
shape = [1, ] + list(shape)
|
| 227 |
-
broadcasted_shape = []
|
| 228 |
-
# Infer broadcast shape
|
| 229 |
-
for shape_dim, src_dim in zip(reversed(shape), reversed(src_shape)):
|
| 230 |
-
if shape_dim == 1:
|
| 231 |
-
broadcasted_shape = [src_dim, ] + list(broadcasted_shape)
|
| 232 |
-
elif src_dim == 1:
|
| 233 |
-
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
| 234 |
-
elif shape_dim == src_dim:
|
| 235 |
-
broadcasted_shape = [shape_dim, ] + list(broadcasted_shape)
|
| 236 |
-
else:
|
| 237 |
-
error_msg = "Dimension mismatch between "
|
| 238 |
-
for src_ in input_node_metas:
|
| 239 |
-
error_msg += f"{src_.name}{src_.tensor.shape}, "
|
| 240 |
-
error_msg = error_msg[:-2] + "."
|
| 241 |
-
raise RuntimeError(error_msg)
|
| 242 |
-
shape = tuple(broadcasted_shape)
|
| 243 |
-
|
| 244 |
-
self._tensor = Tensor(element=self.element_output, shape=shape, layout_tag=LayoutType.RowMajor)
|
| 245 |
-
|
| 246 |
-
def type_propagation(self, *args, **kwargs):
|
| 247 |
-
"""
|
| 248 |
-
Each node is associated with two data types: `element` and `element_output`.
|
| 249 |
-
The `element_output` is the type of return array of the node. The `element`
|
| 250 |
-
has specific meaning for different node types.
|
| 251 |
-
* Load Node: data type of tensor in gmem
|
| 252 |
-
* Compute Node: element compute
|
| 253 |
-
* Store Node: data type of tensor in gmem
|
| 254 |
-
This function must be overloaded in the derived classes
|
| 255 |
-
"""
|
| 256 |
-
raise NotImplementedError(f"Function `type_propagation` is not overloaded in {self.__class__.__name__}")
|
| 257 |
-
|
| 258 |
-
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 259 |
-
"""
|
| 260 |
-
Propagate the broadcast in the reversed topological order.
|
| 261 |
-
For example:
|
| 262 |
-
C[l, m, n] = A[m, 1] + B[l, m, n]
|
| 263 |
-
After the broadcast propagation, it will be come
|
| 264 |
-
C[l, m, n] = A[l, m, n] + B[l, m, n]
|
| 265 |
-
and each tensor will have a proper stride accessing the underlying tensor
|
| 266 |
-
"""
|
| 267 |
-
if self.tensor is None:
|
| 268 |
-
raise RuntimeError(f"The tensor of node {self.name} is unknown.")
|
| 269 |
-
for child in input_node_metas:
|
| 270 |
-
child.tensor.broadcast(self.tensor.shape)
|
| 271 |
-
|
| 272 |
-
def get_underlying_impl(self, problem_size: tuple):
|
| 273 |
-
"""
|
| 274 |
-
Get the underlying implementation of the current node.
|
| 275 |
-
"""
|
| 276 |
-
if self.tensor is None:
|
| 277 |
-
raise RuntimeError(f"The Layout of node {self.name} is unknown. Please call PassShapeTypePropagation first.")
|
| 278 |
-
|
| 279 |
-
for impl in self.possible_impls:
|
| 280 |
-
if impl.match(self, problem_size):
|
| 281 |
-
self.underlying_impl = impl(self)
|
| 282 |
-
break
|
| 283 |
-
|
| 284 |
-
if self.underlying_impl is None:
|
| 285 |
-
raise NotImplementedError(f"No matching op for node {self.name} with stride {self.tensor.stride}.")
|
| 286 |
-
|
| 287 |
-
#
|
| 288 |
-
# Visitor Nodes & Impls
|
| 289 |
-
#
|
| 290 |
-
|
| 291 |
-
class TopoVisitorImpl(ImplBase):
|
| 292 |
-
"""
|
| 293 |
-
Impl for topological visitor
|
| 294 |
-
"""
|
| 295 |
-
def __init__(self, node) -> None:
|
| 296 |
-
super().__init__(node.output_node)
|
| 297 |
-
self.name = node.name
|
| 298 |
-
self.element_output = node.output_node.element_output
|
| 299 |
-
|
| 300 |
-
class TopoVisitorNode(NodeBase):
|
| 301 |
-
def __init__(self, name: str, subgraph, output_node) -> None:
|
| 302 |
-
super().__init__(name)
|
| 303 |
-
self.subgraph = subgraph
|
| 304 |
-
self.output_node = output_node
|
| 305 |
-
self.op = "dag"
|
| 306 |
-
self.underlying_impl = TopoVisitorImpl(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/store_nodes.py
DELETED
|
@@ -1,277 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Store node and implementations
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import ctypes
|
| 38 |
-
|
| 39 |
-
from cutlass_library import DataType
|
| 40 |
-
|
| 41 |
-
from cutlass_cppgen.backend.c_types import tuple_factory
|
| 42 |
-
from cutlass_cppgen.backend.epilogue import dtype2ctype, to_ctype_value
|
| 43 |
-
from cutlass_cppgen.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl
|
| 44 |
-
from cutlass_cppgen.backend.evt.ir.tensor import Tensor
|
| 45 |
-
from cutlass_cppgen.backend.library import FloatRoundStyle, FunctionalOp
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class StoreImplBase(ImplBase):
|
| 49 |
-
"""
|
| 50 |
-
Base class for store node implementation
|
| 51 |
-
"""
|
| 52 |
-
reserved_names = ["D"]
|
| 53 |
-
def __init__(self, node) -> None:
|
| 54 |
-
super().__init__(node)
|
| 55 |
-
self.element = node.element
|
| 56 |
-
self.element_output = node.element_output
|
| 57 |
-
self.stride = node.store_tensor.stride
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
class StoreDImpl(StoreImplBase):
|
| 61 |
-
"""
|
| 62 |
-
Store D implementation
|
| 63 |
-
"""
|
| 64 |
-
|
| 65 |
-
@property
|
| 66 |
-
def argument_type_d(self):
|
| 67 |
-
stride_mnl = self.get_stride_mnl()
|
| 68 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 69 |
-
class _Argument(ctypes.Structure):
|
| 70 |
-
_fields_ = [
|
| 71 |
-
("ptr_D", ctypes.c_void_p),
|
| 72 |
-
("stride_D", tuple_type)
|
| 73 |
-
]
|
| 74 |
-
def __init__(self, ptr: int) -> None:
|
| 75 |
-
self.ptr_D = ptr
|
| 76 |
-
self.stride_D = tuple_type(stride_mnl)
|
| 77 |
-
|
| 78 |
-
return _Argument
|
| 79 |
-
|
| 80 |
-
@staticmethod
|
| 81 |
-
def match(node, problem_size: tuple):
|
| 82 |
-
if node.name == "D" and node.store_tensor.shape == problem_size:
|
| 83 |
-
return True
|
| 84 |
-
return False
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
class AuxStoreImpl(StoreImplBase):
|
| 88 |
-
def __init__(self, node) -> None:
|
| 89 |
-
super().__init__(node)
|
| 90 |
-
self.round_style = FloatRoundStyle.ToNearest
|
| 91 |
-
|
| 92 |
-
@property
|
| 93 |
-
def argument_type(self):
|
| 94 |
-
stride_mnl = self.get_stride_mnl()
|
| 95 |
-
name = self.name
|
| 96 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 97 |
-
class _Argument(ctypes.Structure):
|
| 98 |
-
_fields_ = [
|
| 99 |
-
("ptr_aux", ctypes.c_void_p),
|
| 100 |
-
("dAux", tuple_type)
|
| 101 |
-
]
|
| 102 |
-
def __init__(self, kwargs) -> None:
|
| 103 |
-
ptr = kwargs[name]
|
| 104 |
-
self.ptr_aux = ptr
|
| 105 |
-
self.dAux = tuple_type(stride_mnl)
|
| 106 |
-
|
| 107 |
-
return _Argument
|
| 108 |
-
|
| 109 |
-
@staticmethod
|
| 110 |
-
def match(node, problem_size: tuple):
|
| 111 |
-
if not node.is_output:
|
| 112 |
-
return False
|
| 113 |
-
if node.name in StoreImplBase.reserved_names:
|
| 114 |
-
return False
|
| 115 |
-
|
| 116 |
-
strideMN = node.store_tensor.stride[-2:]
|
| 117 |
-
if (strideMN[0] == 1 and strideMN[1] != 0 or
|
| 118 |
-
strideMN[0] != 0 and strideMN[1] == 1 ):
|
| 119 |
-
return True
|
| 120 |
-
else:
|
| 121 |
-
return False
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
class ReductionImplBase(StoreImplBase):
|
| 125 |
-
def __init__(self, node) -> None:
|
| 126 |
-
super().__init__(node)
|
| 127 |
-
self.element = node.store_tensor.element
|
| 128 |
-
self.element_compute = node.element_compute
|
| 129 |
-
self.reg_reduce_fn = self.node.reg_reduce_fn
|
| 130 |
-
self.gmem_reduce_fn = self.node.gmem_reduce_fn
|
| 131 |
-
self.round_style = node.round_style
|
| 132 |
-
self.stride_dtype = "int"
|
| 133 |
-
|
| 134 |
-
def get_reduce_identity(self):
|
| 135 |
-
"""
|
| 136 |
-
Return the reduction identity of the current reduce_fn
|
| 137 |
-
"""
|
| 138 |
-
maxes = {
|
| 139 |
-
DataType.f32: (2 ** 31) - 1,
|
| 140 |
-
DataType.f16: (2 ** 15),
|
| 141 |
-
DataType.s32: (2 ** 31) - 1,
|
| 142 |
-
DataType.s8: (2 ** 7) - 1
|
| 143 |
-
}
|
| 144 |
-
mins = {
|
| 145 |
-
DataType.f32: -maxes[DataType.f32],
|
| 146 |
-
DataType.f16: -maxes[DataType.f16],
|
| 147 |
-
DataType.s32: -maxes[DataType.s32],
|
| 148 |
-
DataType.s8: -maxes[DataType.s8]
|
| 149 |
-
}
|
| 150 |
-
if self.reg_reduce_fn == FunctionalOp.Maximum:
|
| 151 |
-
if self.element_compute not in mins:
|
| 152 |
-
raise Exception(f"No min entry for data type {self.element_compute}")
|
| 153 |
-
return to_ctype_value(mins[self.element_compute], self.element_compute)
|
| 154 |
-
elif self.reg_reduce_fn == FunctionalOp.Multiplies:
|
| 155 |
-
return to_ctype_value(1., self.element_compute)
|
| 156 |
-
elif self.reg_reduce_fn == FunctionalOp.Minimum:
|
| 157 |
-
if self.element_compute not in maxes:
|
| 158 |
-
raise Exception(f"No max entry for data type {self.element_compute}")
|
| 159 |
-
return to_ctype_value(maxes[self.element_compute], self.element_compute)
|
| 160 |
-
else:
|
| 161 |
-
return to_ctype_value(0., self.element_compute)
|
| 162 |
-
|
| 163 |
-
@property
|
| 164 |
-
def argument_type(self):
|
| 165 |
-
self.get_reduce_identity()
|
| 166 |
-
stride_mnl = self.get_stride_mnl()
|
| 167 |
-
name = self.name
|
| 168 |
-
tuple_type = tuple_factory(stride_mnl, self.stride_dtype)
|
| 169 |
-
element_compute = self.element_compute
|
| 170 |
-
reduce_identity = self.get_reduce_identity()
|
| 171 |
-
class _Argument(ctypes.Structure):
|
| 172 |
-
_fields_ = [
|
| 173 |
-
("ptr", ctypes.c_void_p),
|
| 174 |
-
("reduce_identity", dtype2ctype[element_compute]),
|
| 175 |
-
("dMNL", tuple_type)
|
| 176 |
-
]
|
| 177 |
-
def __init__(self, kwargs) -> None:
|
| 178 |
-
ptr = kwargs[name]
|
| 179 |
-
self.ptr = ptr
|
| 180 |
-
self.reduce_identity = reduce_identity
|
| 181 |
-
self.dMNL = tuple_type(stride_mnl)
|
| 182 |
-
|
| 183 |
-
return _Argument
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
class ColumnReductionImpl(ReductionImplBase):
|
| 187 |
-
|
| 188 |
-
@staticmethod
|
| 189 |
-
def match(node, problem_size: tuple):
|
| 190 |
-
if not node.is_output:
|
| 191 |
-
return False
|
| 192 |
-
if node.name in StoreImplBase.reserved_names:
|
| 193 |
-
return False
|
| 194 |
-
|
| 195 |
-
strideMN = node.store_tensor.stride[-2:]
|
| 196 |
-
if strideMN == (1, 0):
|
| 197 |
-
return True
|
| 198 |
-
else:
|
| 199 |
-
return False
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
class RowReductionImpl(ReductionImplBase):
|
| 203 |
-
|
| 204 |
-
@staticmethod
|
| 205 |
-
def match(node, problem_size: tuple):
|
| 206 |
-
if not node.is_output:
|
| 207 |
-
return False
|
| 208 |
-
if node.name in StoreImplBase.reserved_names:
|
| 209 |
-
return False
|
| 210 |
-
|
| 211 |
-
strideMN = node.store_tensor.stride[-2:]
|
| 212 |
-
if strideMN == (0, 1):
|
| 213 |
-
return True
|
| 214 |
-
else:
|
| 215 |
-
return False
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
class ScalarReductionImpl(ReductionImplBase):
|
| 219 |
-
|
| 220 |
-
@staticmethod
|
| 221 |
-
def match(node, problem_size: tuple):
|
| 222 |
-
if not node.is_output:
|
| 223 |
-
return False
|
| 224 |
-
if node.name in StoreImplBase.reserved_names:
|
| 225 |
-
return False
|
| 226 |
-
|
| 227 |
-
strideMN = node.store_tensor.stride[-2:]
|
| 228 |
-
if strideMN == (0, 0):
|
| 229 |
-
return True
|
| 230 |
-
else:
|
| 231 |
-
return False
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
class StoreNode(NodeBase):
|
| 235 |
-
"""
|
| 236 |
-
Store node
|
| 237 |
-
"""
|
| 238 |
-
possible_impls = [
|
| 239 |
-
AuxStoreImpl, RowReductionImpl,
|
| 240 |
-
ColumnReductionImpl, ScalarReductionImpl,
|
| 241 |
-
NoOpImpl, StoreDImpl
|
| 242 |
-
]
|
| 243 |
-
def __init__(self, name: str) -> None:
|
| 244 |
-
super().__init__(name)
|
| 245 |
-
self.op = "store"
|
| 246 |
-
self.is_output = False
|
| 247 |
-
self._store_tensor = None
|
| 248 |
-
|
| 249 |
-
@property
|
| 250 |
-
def store_tensor(self) -> Tensor:
|
| 251 |
-
"""
|
| 252 |
-
Return the output tensor (concept: cutlass_cppgen.backend.evt.ir.tensor)
|
| 253 |
-
"""
|
| 254 |
-
return self._store_tensor
|
| 255 |
-
|
| 256 |
-
@store_tensor.setter
|
| 257 |
-
def store_tensor(self, kwargs):
|
| 258 |
-
"""
|
| 259 |
-
Setting the tensor
|
| 260 |
-
"""
|
| 261 |
-
self._store_tensor = Tensor(**kwargs)
|
| 262 |
-
|
| 263 |
-
def type_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 264 |
-
"""
|
| 265 |
-
The store nodes has element_output = element_input
|
| 266 |
-
"""
|
| 267 |
-
if self.is_output:
|
| 268 |
-
if self.store_tensor is None:
|
| 269 |
-
raise RuntimeError(f"The store tensor of node {self.name} is unknown.")
|
| 270 |
-
self.element = self.store_tensor.element
|
| 271 |
-
assert len(input_node_metas) == 1, "Store node can only have one input node"
|
| 272 |
-
self.element_output = input_node_metas[0].element_output
|
| 273 |
-
|
| 274 |
-
def broadcast_propagation(self, input_node_metas: 'list[NodeBase]'):
|
| 275 |
-
super().broadcast_propagation(input_node_metas)
|
| 276 |
-
if self.is_output:
|
| 277 |
-
self._store_tensor.broadcast(self.tensor.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/ir/tensor.py
DELETED
|
@@ -1,137 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
High-level class for tensor
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from cutlass_library import LayoutType
|
| 38 |
-
|
| 39 |
-
from cutlass_cppgen.backend.evt.ir.layout_algorithm import (
|
| 40 |
-
Layout,
|
| 41 |
-
broadcast,
|
| 42 |
-
canonicalization,
|
| 43 |
-
permutation,
|
| 44 |
-
reshape,
|
| 45 |
-
_reverse_tuple
|
| 46 |
-
)
|
| 47 |
-
from cutlass_cppgen.utils.datatypes import get_datatype_and_layout, get_tensor_shape, library_type
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
class Tensor:
|
| 51 |
-
"""
|
| 52 |
-
The tensor abstracts the data type
|
| 53 |
-
"""
|
| 54 |
-
def __init__(self, tensor=None, element=None, shape=None, stride=None,layout_tag=None, is_constant=False) -> None:
|
| 55 |
-
if element is not None and tensor is not None:
|
| 56 |
-
raise Exception(f"Must not specify both element and tensor")
|
| 57 |
-
elif shape is not None and tensor is not None:
|
| 58 |
-
raise Exception(f"Must not specify both shape and tensor")
|
| 59 |
-
elif layout_tag is not None and tensor is not None:
|
| 60 |
-
raise Exception(f"Must not specify both layout_tag and tensor")
|
| 61 |
-
elif (element is None or (layout_tag is None and stride is None) or shape is None) and (tensor is None) :
|
| 62 |
-
raise Exception(f"Must specify one of (element, shape, layout/stride) or (tensor)")
|
| 63 |
-
elif stride is not None and tensor is not None:
|
| 64 |
-
raise Exception(f"Must not specify both stride and tensor")
|
| 65 |
-
elif stride is not None and layout_tag is not None:
|
| 66 |
-
raise Exception(f"Must not specify layout_tag when stride is provided")
|
| 67 |
-
|
| 68 |
-
if isinstance(tensor, Tensor):
|
| 69 |
-
# Directly copy all the attributes
|
| 70 |
-
self.__dict__.update(vars(tensor))
|
| 71 |
-
else:
|
| 72 |
-
if tensor is None:
|
| 73 |
-
self.element = library_type(element)
|
| 74 |
-
else:
|
| 75 |
-
self.element, layout_tag = get_datatype_and_layout(tensor)
|
| 76 |
-
shape = get_tensor_shape(tensor)
|
| 77 |
-
if stride is not None:
|
| 78 |
-
self.layout = Layout(shape[::-1], stride[::-1])
|
| 79 |
-
else:
|
| 80 |
-
if layout_tag == LayoutType.RowMajor:
|
| 81 |
-
self.layout = Layout(shape[::-1])
|
| 82 |
-
elif layout_tag == LayoutType.ColumnMajor:
|
| 83 |
-
self.layout = permutation(Layout(shape), [idx for idx in reversed(range(len(shape)))])
|
| 84 |
-
self.layout = canonicalization(self.layout)
|
| 85 |
-
|
| 86 |
-
self.is_constant = is_constant
|
| 87 |
-
# Save the tensor value if it is constant
|
| 88 |
-
if is_constant and tensor is not None:
|
| 89 |
-
self.value = tensor
|
| 90 |
-
|
| 91 |
-
@property
|
| 92 |
-
def shape(self):
|
| 93 |
-
"""
|
| 94 |
-
Returns the RowMajor layout shape
|
| 95 |
-
"""
|
| 96 |
-
return _reverse_tuple(self.layout.shape)
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def stride(self):
|
| 100 |
-
"""
|
| 101 |
-
Returns the RowMajor layout stride
|
| 102 |
-
"""
|
| 103 |
-
return _reverse_tuple(self.layout.stride)
|
| 104 |
-
|
| 105 |
-
@property
|
| 106 |
-
def rank(self):
|
| 107 |
-
"""
|
| 108 |
-
Returns the rank of the tensor
|
| 109 |
-
"""
|
| 110 |
-
return len(self.shape)
|
| 111 |
-
|
| 112 |
-
#
|
| 113 |
-
# Layout Algorithms
|
| 114 |
-
#
|
| 115 |
-
|
| 116 |
-
def broadcast(self, shape):
|
| 117 |
-
"""
|
| 118 |
-
Broadcast self.layout to shape
|
| 119 |
-
"""
|
| 120 |
-
assert isinstance(shape, tuple)
|
| 121 |
-
self.layout = broadcast(self.layout, _reverse_tuple(shape))
|
| 122 |
-
|
| 123 |
-
def reshape(self, shape):
|
| 124 |
-
"""
|
| 125 |
-
Reshape self.layout to shape
|
| 126 |
-
"""
|
| 127 |
-
assert isinstance(shape, tuple)
|
| 128 |
-
reverse_shape = _reverse_tuple(shape)
|
| 129 |
-
self.layout = reshape(self.layout, reverse_shape)
|
| 130 |
-
|
| 131 |
-
def permute(self, indices):
|
| 132 |
-
"""
|
| 133 |
-
Permute self.layout according to indices
|
| 134 |
-
"""
|
| 135 |
-
length = len(indices)
|
| 136 |
-
indices = [length - idx - 1 for idx in indices]
|
| 137 |
-
self.layout = permutation(self.layout, indices[::-1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/__init__.py
DELETED
|
@@ -1,42 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from cutlass_cppgen.backend.evt.passes.graph_drawer import EVTGraphDrawer
|
| 34 |
-
from cutlass_cppgen.backend.evt.passes.pass_argument_type import PassGetArgumentType
|
| 35 |
-
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
| 36 |
-
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
| 37 |
-
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
| 38 |
-
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
| 39 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassManager
|
| 40 |
-
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
| 41 |
-
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 42 |
-
from cutlass_cppgen.backend.evt.passes.smem_size_calculator import GetSmemSize
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/graph_drawer.py
DELETED
|
@@ -1,143 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
from __future__ import annotations
|
| 33 |
-
|
| 34 |
-
import subprocess
|
| 35 |
-
|
| 36 |
-
from cutlass_library import DataTypeTag
|
| 37 |
-
|
| 38 |
-
from cutlass_cppgen.backend.evt.ir.dag_ir import DAGIR
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
_COLOR_MAP = {
|
| 42 |
-
"load": '"AliceBlue"',
|
| 43 |
-
"compute": "LemonChiffon1",
|
| 44 |
-
"accumulator": "LightGrey",
|
| 45 |
-
"store": "PowderBlue",
|
| 46 |
-
"layout": "lightseagreen",
|
| 47 |
-
"dag": "darkorange"
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class EVTGraphDrawer:
|
| 52 |
-
"""
|
| 53 |
-
Visualize a EVT DAGIR with graphviz
|
| 54 |
-
"""
|
| 55 |
-
def __init__(
|
| 56 |
-
self,
|
| 57 |
-
graph: DAGIR,
|
| 58 |
-
name: str
|
| 59 |
-
):
|
| 60 |
-
self._name = name
|
| 61 |
-
self._dot_graphs = {}
|
| 62 |
-
|
| 63 |
-
self._dot_graphs[name] = self._to_dot(graph, name)
|
| 64 |
-
|
| 65 |
-
def _get_node_style(self, node):
|
| 66 |
-
template = {
|
| 67 |
-
"shape": "record",
|
| 68 |
-
"fillcolor": "#CAFFE3",
|
| 69 |
-
"style": '"filled,rounded"',
|
| 70 |
-
"fontcolor": "#000000",
|
| 71 |
-
}
|
| 72 |
-
if node.op in _COLOR_MAP:
|
| 73 |
-
template["fillcolor"] = _COLOR_MAP[node.op]
|
| 74 |
-
else:
|
| 75 |
-
raise NotImplementedError("unknown node op")
|
| 76 |
-
if node.disabled:
|
| 77 |
-
template["fontcolor"] = "grey"
|
| 78 |
-
template["fillcolor"] = "white"
|
| 79 |
-
return template
|
| 80 |
-
|
| 81 |
-
def _get_node_label(self, node):
|
| 82 |
-
label = "{" + f"name={node.name}|op={node.op}"
|
| 83 |
-
if node.op == "layout":
|
| 84 |
-
label += f"|fn={node.fn.__name__}"
|
| 85 |
-
for key in node.kwargs:
|
| 86 |
-
label += f"|{key}={node.kwargs[key]}"
|
| 87 |
-
if node.underlying_impl is not None:
|
| 88 |
-
label += f"|impl={type(node.underlying_impl).__name__}"
|
| 89 |
-
if node.op == "load":
|
| 90 |
-
label += f"|element_output={DataTypeTag[node.underlying_impl.element]}"
|
| 91 |
-
elif node.op == "compute":
|
| 92 |
-
label += f"|element_compute={DataTypeTag[node.underlying_impl.element_compute]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
| 93 |
-
elif node.op == "store":
|
| 94 |
-
label += f"|element_store={DataTypeTag[node.underlying_impl.element]}|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
| 95 |
-
elif node.op == "dag":
|
| 96 |
-
label += f"|element_output={DataTypeTag[node.underlying_impl.element_output]}"
|
| 97 |
-
if node.tensor is not None:
|
| 98 |
-
shape = node.tensor.shape
|
| 99 |
-
stride = node.tensor.stride
|
| 100 |
-
label += f"|shape={shape}|stride={stride}"
|
| 101 |
-
|
| 102 |
-
if hasattr(node, "store_tensor"):
|
| 103 |
-
if node.store_tensor is not None:
|
| 104 |
-
store_shape = node.store_tensor.shape
|
| 105 |
-
store_stride = node.store_tensor.stride
|
| 106 |
-
label += f"|store_shape={store_shape}|stride_stride={store_stride}"
|
| 107 |
-
|
| 108 |
-
label += "}"
|
| 109 |
-
return label
|
| 110 |
-
|
| 111 |
-
def _to_dot(
|
| 112 |
-
self,
|
| 113 |
-
graph: DAGIR,
|
| 114 |
-
name: str
|
| 115 |
-
):
|
| 116 |
-
import pydot
|
| 117 |
-
dot_graph = pydot.Dot(name, randir="TB")
|
| 118 |
-
for node in graph.nodes_meta:
|
| 119 |
-
style = self._get_node_style(node)
|
| 120 |
-
label = self._get_node_label(node)
|
| 121 |
-
dot_node = pydot.Node(
|
| 122 |
-
node.name, label=label, **style
|
| 123 |
-
)
|
| 124 |
-
dot_graph.add_node(dot_node)
|
| 125 |
-
if node.op == "dag":
|
| 126 |
-
dot_subgraph = self._to_dot(node.subgraph, name=node.name)
|
| 127 |
-
self._dot_graphs[node.name] = dot_subgraph
|
| 128 |
-
|
| 129 |
-
# Add edges
|
| 130 |
-
for src, dst in graph.edges:
|
| 131 |
-
weight = graph.get_edge_weight(src, dst)
|
| 132 |
-
dot_graph.add_edge(pydot.Edge(src, dst, label=weight))
|
| 133 |
-
|
| 134 |
-
return dot_graph
|
| 135 |
-
|
| 136 |
-
def get_dot_graph(self) -> pydot.Dot:
|
| 137 |
-
return [(key, self.get_dot_graph_by_name(key)) for key in self._dot_graphs.keys()]
|
| 138 |
-
|
| 139 |
-
def get_dot_graph_by_name(self, name) -> pydot.Dot:
|
| 140 |
-
return self._dot_graphs[name]
|
| 141 |
-
|
| 142 |
-
def get_main_dot_graph(self) -> pydot.Dot:
|
| 143 |
-
return self._dot_graphs[self._name]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_argument_type.py
DELETED
|
@@ -1,120 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Construct the epilogue visitor argument type
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from cutlass_cppgen.backend.c_types import visitor_factory
|
| 38 |
-
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode
|
| 39 |
-
from cutlass_cppgen.backend.evt.passes.pass_dag_2_tree import PassDAG2Tree
|
| 40 |
-
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
| 41 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 42 |
-
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 43 |
-
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class PassGetArgumentType(EVTPassBase):
|
| 47 |
-
"""
|
| 48 |
-
Construct the epilogue visitor argument type
|
| 49 |
-
"""
|
| 50 |
-
dependencies = [
|
| 51 |
-
PassShapeTypePropagation, # The Layout of all nodes must be set
|
| 52 |
-
PassDAG2Tree, # The type of each node must be set
|
| 53 |
-
PassGetImpl # The DAG subgraphs must be set
|
| 54 |
-
]
|
| 55 |
-
|
| 56 |
-
def requires(self) -> None:
|
| 57 |
-
# Check "D" is in the node list
|
| 58 |
-
if cc_map[self.cc] in [90, 100] and (not self.dag_ir.has_node("D")):
|
| 59 |
-
raise SyntaxError(
|
| 60 |
-
"Sm90+ EVT requires the epilogue to have a returned tensor D, "
|
| 61 |
-
"but the variable 'D' is not found in the return values.")
|
| 62 |
-
|
| 63 |
-
def call(self):
|
| 64 |
-
nodes = self.dag_ir.nodes_topological_order()
|
| 65 |
-
self.argument_types = {}
|
| 66 |
-
for node in nodes:
|
| 67 |
-
meta = self.dag_ir.get_node_meta(node)
|
| 68 |
-
if not meta.disabled:
|
| 69 |
-
self.argument_types[node] = meta.underlying_impl.argument_type
|
| 70 |
-
if node == "D" and cc_map[self.cc] in [90, 100]:
|
| 71 |
-
continue
|
| 72 |
-
if isinstance(meta, TopoVisitorNode):
|
| 73 |
-
self.get_dag_argument_type(node)
|
| 74 |
-
else:
|
| 75 |
-
self.get_evt_argument_type(node)
|
| 76 |
-
|
| 77 |
-
self.cc_specific_method(self.set_argument_type)()
|
| 78 |
-
|
| 79 |
-
def get_evt_argument_type(self, node):
|
| 80 |
-
# Sort the input nodes by edge weight
|
| 81 |
-
input_types = [self.argument_types[child] for child in self.dag_ir.get_all_inputs(node)]
|
| 82 |
-
if len(input_types) > 0:
|
| 83 |
-
self.argument_types[node] = visitor_factory(
|
| 84 |
-
input_types + [self.argument_types[node],], self.dag_ir.get_all_inputs(node) + [node,])
|
| 85 |
-
|
| 86 |
-
def get_dag_argument_type(self, node):
|
| 87 |
-
meta = self.dag_ir.get_node_meta(node)
|
| 88 |
-
subgraph = meta.subgraph
|
| 89 |
-
subgraph_nodes = subgraph.nodes_topological_order()
|
| 90 |
-
# Visit the unvisited nodes in subgraph
|
| 91 |
-
for n in subgraph_nodes:
|
| 92 |
-
m = subgraph.get_node_meta(n)
|
| 93 |
-
if m.disabled:
|
| 94 |
-
continue
|
| 95 |
-
else:
|
| 96 |
-
self.argument_types[n] = m.underlying_impl.argument_type
|
| 97 |
-
input_types = [self.argument_types[child] for child in subgraph_nodes[:-1]]
|
| 98 |
-
if len(input_types) > 0:
|
| 99 |
-
self.argument_types[node] = visitor_factory(input_types, subgraph_nodes[:-1])
|
| 100 |
-
|
| 101 |
-
def set_argument_type(self):
|
| 102 |
-
pass
|
| 103 |
-
|
| 104 |
-
def sm90_set_argument_type(self):
|
| 105 |
-
self.dag_ir.epilogue_thread_type = self.argument_types[self.dag_ir.get_all_inputs("D")[0]]
|
| 106 |
-
# Get the tensorD argument type
|
| 107 |
-
self.dag_ir.arg_d_type = self.dag_ir.get_node_meta("D").underlying_impl.argument_type_d
|
| 108 |
-
|
| 109 |
-
# Get the tensorC argument type
|
| 110 |
-
if self.dag_ir.has_node("C"):
|
| 111 |
-
self.dag_ir.arg_c_type = self.dag_ir.get_node_meta("C").underlying_impl.argument_type_c
|
| 112 |
-
else:
|
| 113 |
-
self.dag_ir.arg_c_type = self.dag_ir.arg_d_type
|
| 114 |
-
|
| 115 |
-
def sm100_set_argument_type(self):
|
| 116 |
-
self.sm90_set_argument_type()
|
| 117 |
-
|
| 118 |
-
def sm80_set_argument_type(self):
|
| 119 |
-
nodes = self.dag_ir.nodes_topological_order()
|
| 120 |
-
self.dag_ir.epilogue_thread_type = self.argument_types[nodes[-1]]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_dag_2_tree.py
DELETED
|
@@ -1,169 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Merge non-tree sub-graphs of the DAG IR into a single DAG. The fused DAG will be implemented
|
| 35 |
-
by the topological visitor, while the rest of the graph will be implemented with the tree visitor.
|
| 36 |
-
"""
|
| 37 |
-
|
| 38 |
-
from copy import deepcopy
|
| 39 |
-
|
| 40 |
-
from cutlass_cppgen.backend.evt.ir import DAGIR, TopoVisitorNode
|
| 41 |
-
from cutlass_cppgen.backend.evt.passes.pass_get_impl import PassGetImpl
|
| 42 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 43 |
-
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class PassDAG2Tree(EVTPassBase):
|
| 47 |
-
"""
|
| 48 |
-
Convert the DAG IR to Tree by fusing subgraphs
|
| 49 |
-
"""
|
| 50 |
-
dependencies = [
|
| 51 |
-
PassShapeTypePropagation,
|
| 52 |
-
PassGetImpl
|
| 53 |
-
]
|
| 54 |
-
|
| 55 |
-
def call(self):
|
| 56 |
-
# Step 1: find the nodes that have multiple parents
|
| 57 |
-
multi_parent_nodes = []
|
| 58 |
-
|
| 59 |
-
for node in self.dag_ir.nodes_topological_order():
|
| 60 |
-
if self.dag_ir.out_degree(node) > 1:
|
| 61 |
-
multi_parent_nodes.append(node)
|
| 62 |
-
# Step 2: find the lowest common ancestor (LCA) of all its parents
|
| 63 |
-
for node in multi_parent_nodes:
|
| 64 |
-
# A multi-parent node could be already fused by the previous node
|
| 65 |
-
if not self.dag_ir.has_node(node):
|
| 66 |
-
continue
|
| 67 |
-
# A node uncovered by the previous fusions can have out degree change
|
| 68 |
-
# Case 1: it has <= 1 edges to the previously fused subgraph, no degree change
|
| 69 |
-
# Case 2: it has more than one edges to the previously fused subgraph, degree drops
|
| 70 |
-
if self.dag_ir.out_degree(node) <= 1:
|
| 71 |
-
continue
|
| 72 |
-
|
| 73 |
-
# Otherwise, the node still
|
| 74 |
-
reachable_nodes = []
|
| 75 |
-
# Complexity: O(Dout*N)
|
| 76 |
-
for parent in self.dag_ir.get_users(node):
|
| 77 |
-
reachable_nodes.append(set(self.dag_ir.all_reachable_nodes(parent)))
|
| 78 |
-
# get the common reachable objects
|
| 79 |
-
common_items = set.intersection(*reachable_nodes)
|
| 80 |
-
node_to_fuse = set.union(*reachable_nodes).difference(common_items)
|
| 81 |
-
|
| 82 |
-
lca = None
|
| 83 |
-
# If common ancestor exists, find the lowest one
|
| 84 |
-
if len(common_items) > 0:
|
| 85 |
-
topo_order = self.dag_ir.nodes_topological_order()
|
| 86 |
-
topo_idx = -1
|
| 87 |
-
for item in common_items:
|
| 88 |
-
if lca is None:
|
| 89 |
-
lca = item
|
| 90 |
-
topo_idx = topo_order.index(item)
|
| 91 |
-
else:
|
| 92 |
-
if topo_idx > topo_order.index(item):
|
| 93 |
-
lca = item
|
| 94 |
-
topo_idx = topo_order.index(item)
|
| 95 |
-
else:
|
| 96 |
-
# there is no common ancestor for all the parents, we pack all the reachable
|
| 97 |
-
# nodes into a single DAG node as a fallback. The lca should be the input node of
|
| 98 |
-
# one of the output nodes with out_degree = 0
|
| 99 |
-
potential_output_nodes = []
|
| 100 |
-
for node in node_to_fuse:
|
| 101 |
-
if self.dag_ir.out_degree(node) == 0:
|
| 102 |
-
potential_output_nodes.append(node)
|
| 103 |
-
if len(potential_output_nodes) == 0:
|
| 104 |
-
raise RuntimeError(f"No output node with out degree = 0 found.")
|
| 105 |
-
|
| 106 |
-
output_node = None
|
| 107 |
-
if (self.dag_ir.cc >= 90):
|
| 108 |
-
# For SM90+, the lca should be the input node of D
|
| 109 |
-
if (not self.dag_ir.has_node("D")):
|
| 110 |
-
raise RuntimeError(f"D is not a node in the DAG IR.")
|
| 111 |
-
output_node = "D"
|
| 112 |
-
else:
|
| 113 |
-
output_node = potential_output_nodes[0]
|
| 114 |
-
|
| 115 |
-
if (output_node is None):
|
| 116 |
-
raise RuntimeError(f"No output node found.")
|
| 117 |
-
lca = self.dag_ir.get_all_inputs(output_node)[0]
|
| 118 |
-
node_to_fuse.remove(output_node)
|
| 119 |
-
|
| 120 |
-
# The lca is the output node of the DAG node
|
| 121 |
-
# Get the nodes to be fused
|
| 122 |
-
node_to_fuse.add(lca)
|
| 123 |
-
# Get all the input nodes
|
| 124 |
-
all_input_nodes = []
|
| 125 |
-
all_output_nodes = []
|
| 126 |
-
for node in node_to_fuse:
|
| 127 |
-
all_input_nodes.append(set(self.dag_ir.get_all_inputs(node)))
|
| 128 |
-
all_output_nodes.append(set(self.dag_ir.get_users(node)))
|
| 129 |
-
all_input_nodes = set.union(*all_input_nodes)
|
| 130 |
-
all_output_nodes = set.union(*all_output_nodes)
|
| 131 |
-
|
| 132 |
-
new_subgraph_nodes = set.union(node_to_fuse, all_input_nodes, all_output_nodes)
|
| 133 |
-
|
| 134 |
-
# Create the subgraph
|
| 135 |
-
subgraph_ = self.dag_ir._graph.subgraph(new_subgraph_nodes)
|
| 136 |
-
subgraph = DAGIR(self.dag_ir.cc)
|
| 137 |
-
for node in subgraph_.nodes:
|
| 138 |
-
meta = deepcopy(self.dag_ir.get_node_meta(node))
|
| 139 |
-
if node not in node_to_fuse:
|
| 140 |
-
meta.disabled = True
|
| 141 |
-
subgraph.add_node(meta)
|
| 142 |
-
for edge in subgraph_.edges:
|
| 143 |
-
subgraph.add_edge(edge[0], edge[1], self.dag_ir.get_edge_weight(edge[0], edge[1]))
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
# Create the fused node
|
| 147 |
-
dag_node = TopoVisitorNode(
|
| 148 |
-
name=f"dag_{lca}", subgraph=subgraph,
|
| 149 |
-
output_node=self.dag_ir.get_node_meta(lca))
|
| 150 |
-
self.dag_ir.add_node(dag_node)
|
| 151 |
-
|
| 152 |
-
# Add input edges
|
| 153 |
-
for idx, node in enumerate(all_input_nodes):
|
| 154 |
-
self.dag_ir.add_edge(node, dag_node.name, weight=idx)
|
| 155 |
-
|
| 156 |
-
# Replace all uses with DAG node (only 1 output node)
|
| 157 |
-
self.dag_ir.replace_all_uses_with(lca, dag_node.name)
|
| 158 |
-
|
| 159 |
-
# Remove all fused nodes
|
| 160 |
-
node_to_fuse.remove(lca)
|
| 161 |
-
for node in node_to_fuse:
|
| 162 |
-
self.dag_ir.remove_node(node)
|
| 163 |
-
|
| 164 |
-
def ensures(self) -> None:
|
| 165 |
-
# Ensure that after the pass, the resulting DAG becomes a tree
|
| 166 |
-
for node in self.dag_ir.nodes:
|
| 167 |
-
out_degree = self.dag_ir.out_degree(node)
|
| 168 |
-
if out_degree > 1:
|
| 169 |
-
raise RuntimeError(f"PassDAG2Tree failed. Node {node} still have outdegree = {out_degree}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_fix_element_d.py
DELETED
|
@@ -1,64 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Fix the element_output of producer of D.
|
| 35 |
-
|
| 36 |
-
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
| 37 |
-
element converter, so the compute node producing D must have element_output = type(D).
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
from cutlass_cppgen.backend.evt.passes.pass_layout_elimination import PassLayoutManipulateElimination
|
| 41 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class PassFixElementD(EVTPassBase):
|
| 45 |
-
"""
|
| 46 |
-
In Sm90 epilogue visitor, the node writing D to gmem does not have internal
|
| 47 |
-
element converter, so the compute node producing D must have
|
| 48 |
-
element_output = type(D)
|
| 49 |
-
"""
|
| 50 |
-
dependencies = [
|
| 51 |
-
PassLayoutManipulateElimination
|
| 52 |
-
]
|
| 53 |
-
def get_producer(self, node, element_D):
|
| 54 |
-
node_meta = self.dag_ir.get_node_meta(node)
|
| 55 |
-
if node_meta.op == "compute":
|
| 56 |
-
node_meta.element_output = element_D
|
| 57 |
-
elif node_meta.op == "store":
|
| 58 |
-
self.get_producer(self.dag_ir.get_all_inputs(node)[0], element_D)
|
| 59 |
-
|
| 60 |
-
def call(self):
|
| 61 |
-
if self.dag_ir.has_node("D"):
|
| 62 |
-
node_d_meta = self.dag_ir.get_node_meta("D")
|
| 63 |
-
element_D = node_d_meta.store_tensor.element
|
| 64 |
-
self.get_producer("D", element_D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_get_impl.py
DELETED
|
@@ -1,90 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Infer the underlying implement of each node.
|
| 35 |
-
|
| 36 |
-
While the frontend only distinguish between Load/Store/Compute Node,
|
| 37 |
-
each of these nodes can have different underlying implementation based
|
| 38 |
-
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
| 39 |
-
This pass infers the underlying impl of each node
|
| 40 |
-
"""
|
| 41 |
-
|
| 42 |
-
import cutlass_cppgen.backend.evt.backend as evt_backend
|
| 43 |
-
from cutlass_cppgen.backend.evt.ir import DAGIR, LoadNode
|
| 44 |
-
from cutlass_cppgen.backend.evt.passes.pass_fix_element_d import PassFixElementD
|
| 45 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 46 |
-
from cutlass_cppgen.backend.evt.passes.pass_no_op_elimination import PassNoOpElimination
|
| 47 |
-
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 48 |
-
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class PassGetImpl(EVTPassBase):
|
| 52 |
-
"""
|
| 53 |
-
While the frontend only distinguish between Load/Store/Compute Node,
|
| 54 |
-
each of these nodes can have different underlying implementation based
|
| 55 |
-
on their layout. For instance, a LoadNode can be AuxLoad, Row/Col/Scalar broadcast, etc.
|
| 56 |
-
This pass infers the underlying impl of each node
|
| 57 |
-
"""
|
| 58 |
-
dependencies = [
|
| 59 |
-
PassShapeTypePropagation, # The shape and type info are required for inference
|
| 60 |
-
PassFixElementD
|
| 61 |
-
]
|
| 62 |
-
|
| 63 |
-
def __init__(self, dag_ir: DAGIR) -> None:
|
| 64 |
-
super().__init__(dag_ir)
|
| 65 |
-
self.no_op_elimination = PassNoOpElimination(dag_ir)
|
| 66 |
-
|
| 67 |
-
def requires(self) -> None:
|
| 68 |
-
# Verify "accum" is in the arg list
|
| 69 |
-
if not self.dag_ir.has_node("accum"):
|
| 70 |
-
raise SyntaxError("Cannot find 'accum' in the argument list.")
|
| 71 |
-
|
| 72 |
-
def call(self):
|
| 73 |
-
# The loop structure of the epilogue is determined by the
|
| 74 |
-
# accumulator shape
|
| 75 |
-
accumulator: LoadNode = self.dag_ir.get_node_meta("accum")
|
| 76 |
-
problem_size = accumulator.tensor.shape
|
| 77 |
-
|
| 78 |
-
for node_meta in self.dag_ir.node_metas_topological_order():
|
| 79 |
-
node_meta.get_underlying_impl(problem_size)
|
| 80 |
-
|
| 81 |
-
def ensures(self) -> None:
|
| 82 |
-
# Some nodes will be lowered to NoOp, eliminate them
|
| 83 |
-
self.no_op_elimination()
|
| 84 |
-
# Lower to cc-specific impl
|
| 85 |
-
for node_meta in self.dag_ir.nodes_meta:
|
| 86 |
-
node_impl_ccs = getattr(evt_backend, f"sm{cc_map[self.cc]}_nodes")
|
| 87 |
-
node_meta.underlying_impl = getattr(
|
| 88 |
-
node_impl_ccs,
|
| 89 |
-
f"Sm{cc_map[self.cc]}" + node_meta.underlying_impl.__class__.__name__
|
| 90 |
-
)(node_meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_layout_elimination.py
DELETED
|
@@ -1,217 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Eliminate layout manipulation nodes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from copy import deepcopy
|
| 38 |
-
|
| 39 |
-
from cutlass_cppgen.backend.evt.ir import DAGIR, LayoutNode
|
| 40 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 41 |
-
from cutlass_cppgen.backend.evt.passes.pass_shape_type_propagation import PassShapeTypePropagation
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
class PassLayoutManipulateElimination(EVTPassBase):
|
| 45 |
-
"""
|
| 46 |
-
Eliminate layout manipulation nodes
|
| 47 |
-
"""
|
| 48 |
-
dependencies = [PassShapeTypePropagation]
|
| 49 |
-
|
| 50 |
-
def __init__(self, dag_ir: DAGIR) -> None:
|
| 51 |
-
super().__init__(dag_ir)
|
| 52 |
-
self.copy_cnt = 0
|
| 53 |
-
|
| 54 |
-
def call(self):
|
| 55 |
-
self.layout_nodes_worklist = self.get_all_layout_nodes()
|
| 56 |
-
# Run while loop utill all layout nodes are eliminated
|
| 57 |
-
while(len(self.layout_nodes_worklist) > 0):
|
| 58 |
-
node = self.layout_nodes_worklist.pop(0)
|
| 59 |
-
# for node in layout_nodes:
|
| 60 |
-
# Step 1: get the propagation direction
|
| 61 |
-
direction = self.get_propagation_direction(node)
|
| 62 |
-
self.visited = []
|
| 63 |
-
getattr(self, f"propagate_to_{direction}")(self.dag_ir.get_node_meta(node), node)
|
| 64 |
-
# Eliminate the current node
|
| 65 |
-
input_node = self.dag_ir.get_all_inputs(node)[0]
|
| 66 |
-
self.dag_ir.replace_all_uses_with(node, input_node)
|
| 67 |
-
# layout_nodes = self.get_all_layout_nodes()
|
| 68 |
-
|
| 69 |
-
def get_all_layout_nodes(self):
|
| 70 |
-
layout_nodes = []
|
| 71 |
-
for node_meta in reversed(self.dag_ir.node_metas_topological_order()):
|
| 72 |
-
if isinstance(node_meta, LayoutNode):
|
| 73 |
-
layout_nodes.append(node_meta.name)
|
| 74 |
-
return layout_nodes
|
| 75 |
-
|
| 76 |
-
def get_propagation_direction(self, node: str):
|
| 77 |
-
"""
|
| 78 |
-
The logic is propagating all layout nodes away from the accumulator node.
|
| 79 |
-
"""
|
| 80 |
-
self.visited = []
|
| 81 |
-
self.get_influenced_users(node)
|
| 82 |
-
nodes_influenced_dir_users = self.visited
|
| 83 |
-
self.visited = []
|
| 84 |
-
self.get_influenced_inputs(node)
|
| 85 |
-
nodes_influenced_dir_inputs = self.visited
|
| 86 |
-
|
| 87 |
-
if "accum" in nodes_influenced_dir_users and "accum" not in nodes_influenced_dir_inputs:
|
| 88 |
-
return "inputs"
|
| 89 |
-
elif "accum" not in nodes_influenced_dir_users and "accum" in nodes_influenced_dir_inputs:
|
| 90 |
-
return "users"
|
| 91 |
-
else:
|
| 92 |
-
raise RuntimeError("Unsolved propagation direction")
|
| 93 |
-
|
| 94 |
-
# Get all influenced nodes if we propagate along the user direction
|
| 95 |
-
def get_influenced_users(self, node: str):
|
| 96 |
-
if node in self.visited:
|
| 97 |
-
return
|
| 98 |
-
self.visited.append(node)
|
| 99 |
-
|
| 100 |
-
users = self.dag_ir.get_users(node)
|
| 101 |
-
for user in users:
|
| 102 |
-
self.get_influenced_users(user)
|
| 103 |
-
user_inputs = []
|
| 104 |
-
for user in users:
|
| 105 |
-
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
|
| 106 |
-
if len(user_inputs) > 0:
|
| 107 |
-
user_inputs = set.union(*user_inputs)
|
| 108 |
-
user_inputs.remove(node)
|
| 109 |
-
for input in user_inputs:
|
| 110 |
-
self.get_influenced_inputs(input)
|
| 111 |
-
|
| 112 |
-
# Get all influenced nodes if we propagate along the input direction
|
| 113 |
-
def get_influenced_inputs(self, node: str):
|
| 114 |
-
if node in self.visited:
|
| 115 |
-
return
|
| 116 |
-
self.visited.append(node)
|
| 117 |
-
|
| 118 |
-
inputs = self.dag_ir.get_all_inputs(node)
|
| 119 |
-
for input in inputs:
|
| 120 |
-
self.get_influenced_inputs(input)
|
| 121 |
-
input_users = []
|
| 122 |
-
for input in inputs:
|
| 123 |
-
input_users.append(set(self.dag_ir.get_users(input)))
|
| 124 |
-
if len(input_users) > 0:
|
| 125 |
-
input_users = set.union(*input_users)
|
| 126 |
-
input_users.remove(node)
|
| 127 |
-
for user in input_users:
|
| 128 |
-
self.get_influenced_users(user)
|
| 129 |
-
|
| 130 |
-
def add_copy_before(self, layout_node_meta: LayoutNode, target: str):
|
| 131 |
-
copied_node_meta = deepcopy(layout_node_meta)
|
| 132 |
-
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
| 133 |
-
self.copy_cnt += 1
|
| 134 |
-
copied_node_meta.name = copied_node
|
| 135 |
-
self.dag_ir.add_node(copied_node_meta)
|
| 136 |
-
# Add edges
|
| 137 |
-
target_inputs = self.dag_ir.get_all_inputs(target)
|
| 138 |
-
for src in target_inputs:
|
| 139 |
-
self.dag_ir.remove_edge(src, target)
|
| 140 |
-
self.dag_ir.add_edge(src, copied_node)
|
| 141 |
-
self.dag_ir.add_edge(copied_node, target)
|
| 142 |
-
self.layout_nodes_worklist.append(copied_node)
|
| 143 |
-
|
| 144 |
-
def add_copy_after(self, layout_node_meta: LayoutNode, target: str):
|
| 145 |
-
copied_node_meta = deepcopy(layout_node_meta)
|
| 146 |
-
copied_node = f"{copied_node_meta.name}_copy{self.copy_cnt}"
|
| 147 |
-
self.copy_cnt += 1
|
| 148 |
-
copied_node_meta.name = copied_node
|
| 149 |
-
self.dag_ir.add_node(copied_node_meta)
|
| 150 |
-
# Add edges
|
| 151 |
-
users = self.dag_ir.get_users(target)
|
| 152 |
-
for user in users:
|
| 153 |
-
self.dag_ir.remove_edge(target, user)
|
| 154 |
-
self.dag_ir.add_edge(copied_node, user)
|
| 155 |
-
self.dag_ir.add_edge(target, copied_node)
|
| 156 |
-
self.layout_nodes_worklist.append(copied_node)
|
| 157 |
-
|
| 158 |
-
# Propagate the layout `node` along the user direction
|
| 159 |
-
def propagate_to_users(self, layout_node_meta: LayoutNode, node: str):
|
| 160 |
-
"""
|
| 161 |
-
Propagate layout node to users
|
| 162 |
-
"""
|
| 163 |
-
if node in self.visited:
|
| 164 |
-
# Avoid applying twice
|
| 165 |
-
return
|
| 166 |
-
self.visited.append(node)
|
| 167 |
-
|
| 168 |
-
node_meta = self.dag_ir.get_node_meta(node)
|
| 169 |
-
if layout_node_meta.name != node:
|
| 170 |
-
if isinstance(node_meta, LayoutNode):
|
| 171 |
-
# Layout node is not transparent with layout node
|
| 172 |
-
self.add_copy_before(layout_node_meta, node)
|
| 173 |
-
return
|
| 174 |
-
else:
|
| 175 |
-
layout_node_meta.apply_to_user(node_meta)
|
| 176 |
-
|
| 177 |
-
users = self.dag_ir.get_users(node)
|
| 178 |
-
user_inputs = []
|
| 179 |
-
for user in users:
|
| 180 |
-
user_inputs.append(set(self.dag_ir.get_all_inputs(user)))
|
| 181 |
-
for user in users:
|
| 182 |
-
self.propagate_to_users(layout_node_meta, user)
|
| 183 |
-
if len(user_inputs) > 0:
|
| 184 |
-
user_inputs = set.union(*user_inputs)
|
| 185 |
-
user_inputs.remove(node)
|
| 186 |
-
for input in user_inputs:
|
| 187 |
-
self.propagate_to_inputs(layout_node_meta.get_inverse_node(), input)
|
| 188 |
-
|
| 189 |
-
# Propagate the layout `node` along the input direction
|
| 190 |
-
def propagate_to_inputs(self, layout_node_meta: LayoutNode, node: str):
|
| 191 |
-
"""
|
| 192 |
-
Propagate layout node to inputs
|
| 193 |
-
"""
|
| 194 |
-
if node in self.visited:
|
| 195 |
-
# Avoid applying twice
|
| 196 |
-
return
|
| 197 |
-
self.visited.append(node)
|
| 198 |
-
|
| 199 |
-
node_meta = self.dag_ir.get_node_meta(node)
|
| 200 |
-
if layout_node_meta.name != node:
|
| 201 |
-
if isinstance(node_meta, LayoutNode):
|
| 202 |
-
# Layout node is not transparent with layout node
|
| 203 |
-
self.add_copy_after(layout_node_meta, node)
|
| 204 |
-
return
|
| 205 |
-
else:
|
| 206 |
-
layout_node_meta.apply_to_input(node_meta)
|
| 207 |
-
inputs = self.dag_ir.get_all_inputs(node)
|
| 208 |
-
input_users = []
|
| 209 |
-
for input in inputs:
|
| 210 |
-
input_users.append(set(self.dag_ir.get_users(input)))
|
| 211 |
-
for input in inputs:
|
| 212 |
-
self.propagate_to_inputs(layout_node_meta, input)
|
| 213 |
-
if len(input_users) > 0:
|
| 214 |
-
input_users = set.union(*input_users)
|
| 215 |
-
input_users.remove(node)
|
| 216 |
-
for user in input_users:
|
| 217 |
-
self.propagate_to_users(layout_node_meta.get_inverse_node(), user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_manager.py
DELETED
|
@@ -1,164 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Pass manager for DAG IR.
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from typing import Any
|
| 38 |
-
|
| 39 |
-
import networkx as nx
|
| 40 |
-
|
| 41 |
-
from cutlass_cppgen.backend.evt.ir import DAGIR
|
| 42 |
-
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class EVTPassBase:
|
| 46 |
-
"""
|
| 47 |
-
Base class for EVT Passes
|
| 48 |
-
"""
|
| 49 |
-
dependencies = []
|
| 50 |
-
def __init__(self, dag_ir: DAGIR) -> None:
|
| 51 |
-
self.dag_ir = dag_ir
|
| 52 |
-
self.cc = self.dag_ir.cc
|
| 53 |
-
|
| 54 |
-
def requires(self) -> None:
|
| 55 |
-
"""
|
| 56 |
-
This function will be called before the pass is run.
|
| 57 |
-
"""
|
| 58 |
-
pass
|
| 59 |
-
|
| 60 |
-
def call(self) -> None:
|
| 61 |
-
"""
|
| 62 |
-
The pass that is run through the self.dag_ir
|
| 63 |
-
"""
|
| 64 |
-
raise NotImplementedError(
|
| 65 |
-
f"__call__ is not overwritten in Pass {self.__class__.__name__}")
|
| 66 |
-
|
| 67 |
-
def ensures(self) -> None:
|
| 68 |
-
"""
|
| 69 |
-
This function will be called after the pass is run.
|
| 70 |
-
"""
|
| 71 |
-
pass
|
| 72 |
-
|
| 73 |
-
def __call__(self) -> Any:
|
| 74 |
-
self.requires()
|
| 75 |
-
self.call()
|
| 76 |
-
self.ensures()
|
| 77 |
-
|
| 78 |
-
def cc_specific_method(self, func):
|
| 79 |
-
"""
|
| 80 |
-
This enables defining function that behaves differently under different cc
|
| 81 |
-
The simplest example of using this function is the following
|
| 82 |
-
|
| 83 |
-
.. highlight:: python
|
| 84 |
-
.. code-block:: python
|
| 85 |
-
|
| 86 |
-
class ExamplePass(EVTPassBase):
|
| 87 |
-
|
| 88 |
-
def call(sekf):
|
| 89 |
-
# This automatically select the smXX_func based on current cc
|
| 90 |
-
self.cc_specific_method(self.func)()
|
| 91 |
-
|
| 92 |
-
# Interface func, can be empty
|
| 93 |
-
def func(self):
|
| 94 |
-
pass
|
| 95 |
-
|
| 96 |
-
# Sm90 specific func
|
| 97 |
-
def sm90_func(self):
|
| 98 |
-
// sm90 specific method
|
| 99 |
-
return
|
| 100 |
-
|
| 101 |
-
# Sm80 specific func
|
| 102 |
-
def sm80_func(self):
|
| 103 |
-
// sm80 specific method
|
| 104 |
-
return
|
| 105 |
-
"""
|
| 106 |
-
func_name = f"sm{cc_map[self.cc]}_{func.__name__}"
|
| 107 |
-
if hasattr(self, func_name):
|
| 108 |
-
return getattr(self, func_name)
|
| 109 |
-
else:
|
| 110 |
-
raise NotImplementedError(f"func {func.__name__} is not overwritten for Sm{self.cc}")
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
class EVTPassManager(nx.DiGraph):
|
| 114 |
-
"""
|
| 115 |
-
Topological-based Pass Manager.
|
| 116 |
-
Each registered pass has a list of dependencies. The pass manager organizes
|
| 117 |
-
the passes as a DAG and launch the compiler passes under topological order.
|
| 118 |
-
"""
|
| 119 |
-
def __init__(self, dag_ir: DAGIR, pass_list):
|
| 120 |
-
super().__init__()
|
| 121 |
-
self.dag_ir = dag_ir
|
| 122 |
-
for pass_cls in pass_list:
|
| 123 |
-
self.add_pass(pass_cls)
|
| 124 |
-
|
| 125 |
-
self.sorted_passes = self.schedule()
|
| 126 |
-
|
| 127 |
-
def get_callable(self, pass_name):
|
| 128 |
-
"""
|
| 129 |
-
Return the callable of the pass
|
| 130 |
-
"""
|
| 131 |
-
return self.nodes[pass_name]["callable"]
|
| 132 |
-
|
| 133 |
-
def add_pass(self, pass_cls):
|
| 134 |
-
"""
|
| 135 |
-
Add a pass to the pass manager
|
| 136 |
-
:param pass_cls: the class of pass
|
| 137 |
-
:type pass_cls: derived class of EVTPassBase
|
| 138 |
-
"""
|
| 139 |
-
name = pass_cls.__name__
|
| 140 |
-
pass_callable = pass_cls(self.dag_ir)
|
| 141 |
-
self.add_node(name, callable=pass_callable)
|
| 142 |
-
|
| 143 |
-
def schedule(self):
|
| 144 |
-
"""
|
| 145 |
-
Schedule the added passes under topological order
|
| 146 |
-
"""
|
| 147 |
-
# Add edges
|
| 148 |
-
for pass_name in self.nodes:
|
| 149 |
-
callable = self.get_callable(pass_name)
|
| 150 |
-
for dependency_cls in callable.dependencies:
|
| 151 |
-
self.add_edge(
|
| 152 |
-
dependency_cls.__name__,
|
| 153 |
-
type(callable).__name__)
|
| 154 |
-
|
| 155 |
-
# Topological sort
|
| 156 |
-
return list(nx.topological_sort(self))
|
| 157 |
-
|
| 158 |
-
def __call__(self) -> Any:
|
| 159 |
-
"""
|
| 160 |
-
Launch the registered passes
|
| 161 |
-
"""
|
| 162 |
-
for pass_name in self.sorted_passes:
|
| 163 |
-
callable = self.get_callable(pass_name)
|
| 164 |
-
callable()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_no_op_elimination.py
DELETED
|
@@ -1,53 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
No op elimination node
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from typing import Any
|
| 38 |
-
|
| 39 |
-
from cutlass_cppgen.backend.evt.ir import NoOpImpl
|
| 40 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class PassNoOpElimination(EVTPassBase):
|
| 44 |
-
"""
|
| 45 |
-
The dead node elimination pass removes nodes with NoOpImpl in DAG IR
|
| 46 |
-
"""
|
| 47 |
-
dependencies = []
|
| 48 |
-
|
| 49 |
-
def call(self) -> Any:
|
| 50 |
-
for node in self.dag_ir.nodes_topological_order():
|
| 51 |
-
node_meta = self.dag_ir.get_node_meta(node)
|
| 52 |
-
if isinstance(node_meta.underlying_impl, NoOpImpl):
|
| 53 |
-
self.dag_ir.replace_all_uses_with(node, self.dag_ir.get_all_inputs(node)[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_preprocess_red.py
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Preprocess the reduction nodes.
|
| 35 |
-
|
| 36 |
-
The parser treats reduction as Compute(op=(reg_reduce_fn, gmem_reduce_fn)) - Store()
|
| 37 |
-
This pass fuses these into a single store node, and then replaces all uses of the
|
| 38 |
-
current node with the new store node.
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
-
from cutlass_cppgen.backend.evt.ir import ComputeNode, StoreNode
|
| 42 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class PassPreprocessRed(EVTPassBase):
|
| 46 |
-
"""
|
| 47 |
-
Preprocess red nodes
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
def call(self):
|
| 51 |
-
# Step 1: find the compute nodes with op=red
|
| 52 |
-
red_compute_nodes = []
|
| 53 |
-
for node_meta in self.dag_ir.nodes_meta:
|
| 54 |
-
if isinstance(node_meta, ComputeNode):
|
| 55 |
-
if type(node_meta.fn) == tuple:
|
| 56 |
-
# To keep the frontend simple, the reduction nodes
|
| 57 |
-
# are parsed into compute nodes by default
|
| 58 |
-
# The simple heuristic to distinguish between compute
|
| 59 |
-
# and reduction node is that compute node is a single function,
|
| 60 |
-
# while the reduction node is a tuple of functions for
|
| 61 |
-
# in-register reduction and atomic global memory reduction
|
| 62 |
-
red_compute_nodes.append(node_meta.name)
|
| 63 |
-
|
| 64 |
-
# Step 2: for each compute, merge it with the succeeding store
|
| 65 |
-
for node in red_compute_nodes:
|
| 66 |
-
# Verify
|
| 67 |
-
users = self.dag_ir.get_users(node)
|
| 68 |
-
inputs = self.dag_ir.get_all_inputs(node)
|
| 69 |
-
# Has a single user
|
| 70 |
-
assert len(users) == 1
|
| 71 |
-
assert len(inputs) == 1
|
| 72 |
-
user = users[0]
|
| 73 |
-
input = inputs[0]
|
| 74 |
-
|
| 75 |
-
user_meta = self.dag_ir.get_node_meta(user)
|
| 76 |
-
# Must be a store node
|
| 77 |
-
assert isinstance(user_meta, StoreNode)
|
| 78 |
-
# With output degree == 0
|
| 79 |
-
assert self.dag_ir.out_degree(user) == 0
|
| 80 |
-
# Register the reduce op
|
| 81 |
-
node_meta = self.dag_ir.get_node_meta(node)
|
| 82 |
-
user_meta.reg_reduce_fn, user_meta.gmem_reduce_fn = node_meta.fn
|
| 83 |
-
user_meta.element_compute = node_meta.element_compute
|
| 84 |
-
user_meta.round_style = node_meta.round_style
|
| 85 |
-
|
| 86 |
-
# Replace all uses
|
| 87 |
-
self.dag_ir.remove_edge(input, node)
|
| 88 |
-
input_users = self.dag_ir.get_users(input)
|
| 89 |
-
for iu in input_users:
|
| 90 |
-
weight = self.dag_ir.get_edge_weight(input, iu)
|
| 91 |
-
self.dag_ir.add_edge(user, iu, weight)
|
| 92 |
-
self.dag_ir.remove_edge(input, iu)
|
| 93 |
-
self.dag_ir.add_edge(input, user)
|
| 94 |
-
self.dag_ir.remove_node(node)
|
| 95 |
-
|
| 96 |
-
# Register the reduction name
|
| 97 |
-
self.dag_ir.reduction_names.append(user)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/pass_shape_type_propagation.py
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Shape and type propagation pass
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from cutlass_cppgen.backend.evt.ir.node import NodeBase
|
| 38 |
-
from cutlass_cppgen.backend.evt.passes.pass_manager import EVTPassBase
|
| 39 |
-
from cutlass_cppgen.backend.evt.passes.pass_preprocess_red import PassPreprocessRed
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class PassShapeTypePropagation(EVTPassBase):
|
| 43 |
-
"""
|
| 44 |
-
Propagate the shape and type of all nodes
|
| 45 |
-
"""
|
| 46 |
-
dependencies = [PassPreprocessRed]
|
| 47 |
-
|
| 48 |
-
def call(self):
|
| 49 |
-
# Propagate the node shape and type
|
| 50 |
-
for node in self.dag_ir.nodes_topological_order():
|
| 51 |
-
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
|
| 52 |
-
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
|
| 53 |
-
node_meta.type_propagation(input_node_metas)
|
| 54 |
-
node_meta.shape_propagation(input_node_metas)
|
| 55 |
-
|
| 56 |
-
for node in reversed(self.dag_ir.nodes_topological_order()):
|
| 57 |
-
node_meta: NodeBase = self.dag_ir.get_node_meta(node)
|
| 58 |
-
input_node_metas = self.dag_ir.get_all_inputs_meta(node)
|
| 59 |
-
node_meta.broadcast_propagation(input_node_metas)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/smem_size_calculator.py
DELETED
|
@@ -1,319 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Compute the shared memory size in bytes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from math import gcd
|
| 38 |
-
|
| 39 |
-
import cutlass_library
|
| 40 |
-
from pycute import flatten, shape_div, product
|
| 41 |
-
|
| 42 |
-
import cutlass_cppgen
|
| 43 |
-
from cutlass_cppgen.backend.evt.ir import TopoVisitorNode, DAGIR
|
| 44 |
-
from cutlass_cppgen.backend.library import DataType, DataTypeSize
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
class GetSmemSize:
|
| 48 |
-
"""
|
| 49 |
-
Get the size in byte of shared memory used by the kernel
|
| 50 |
-
"""
|
| 51 |
-
def __init__(self, dag_ir: DAGIR) -> None:
|
| 52 |
-
self.dag_ir = dag_ir
|
| 53 |
-
self.cc = self.dag_ir.cc
|
| 54 |
-
|
| 55 |
-
#
|
| 56 |
-
# Sm90 epilogue specific
|
| 57 |
-
#
|
| 58 |
-
|
| 59 |
-
def sm90_epilogue_tile(self, tile_description):
|
| 60 |
-
# Get the epilogue tile size
|
| 61 |
-
schedule = tile_description.epilogue_schedule
|
| 62 |
-
if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized:
|
| 63 |
-
element_d = self.dag_ir.get_node_meta("D").element
|
| 64 |
-
nperf = 64 if (DataTypeSize[element_d] == 8 and tile_description.threadblock_shape[1] % 64 == 0) else 32
|
| 65 |
-
epi_tile_m = min(64, tile_description.threadblock_shape[0])
|
| 66 |
-
epi_tile_n = gcd(min(nperf, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
|
| 67 |
-
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
|
| 68 |
-
elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative:
|
| 69 |
-
epi_tile_m = min(128, tile_description.threadblock_shape[0])
|
| 70 |
-
epi_tile_n = gcd(min(32, tile_description.threadblock_shape[1]), tile_description.threadblock_shape[1])
|
| 71 |
-
epilogue_tile_mn = (epi_tile_m, epi_tile_n)
|
| 72 |
-
else:
|
| 73 |
-
raise NotImplementedError(f"Unsupported schedule: {schedule}")
|
| 74 |
-
|
| 75 |
-
# Get the pipeline stages
|
| 76 |
-
stages_d = 2
|
| 77 |
-
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
|
| 78 |
-
if self.dag_ir.has_node("C"):
|
| 79 |
-
element_c = self.dag_ir.get_node_meta("C").element
|
| 80 |
-
else:
|
| 81 |
-
element_c = None
|
| 82 |
-
|
| 83 |
-
element_d = self.dag_ir.get_node_meta("D").element
|
| 84 |
-
if element_c == element_d:
|
| 85 |
-
reuse_smem_c = True
|
| 86 |
-
else:
|
| 87 |
-
reuse_smem_c = False
|
| 88 |
-
stages_c = max(epi_tiles, stages_d + 1) if reuse_smem_c else epi_tiles
|
| 89 |
-
|
| 90 |
-
# Record the epilogue tile
|
| 91 |
-
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
|
| 92 |
-
self.epilogue_tile_mn = epilogue_tile_mn
|
| 93 |
-
self.epi_tiles = epi_tiles
|
| 94 |
-
self.stages_c = stages_c
|
| 95 |
-
self.stages_d = stages_d
|
| 96 |
-
self.reuse_smem_c = reuse_smem_c
|
| 97 |
-
self.element_c = element_c
|
| 98 |
-
self.element_d = element_d
|
| 99 |
-
self.is_source_supported = element_c is not None
|
| 100 |
-
|
| 101 |
-
def sm90_or_sm100_epilogue_smem_size(self, tile_description):
|
| 102 |
-
# Get the Fusion Storage
|
| 103 |
-
nodes = self.dag_ir.nodes_topological_order()
|
| 104 |
-
self.smem_types = {}
|
| 105 |
-
for node in nodes:
|
| 106 |
-
meta = self.dag_ir.get_node_meta(node)
|
| 107 |
-
if not meta.disabled:
|
| 108 |
-
self.smem_types[node] = meta.underlying_impl.get_smem_size(
|
| 109 |
-
self.cta_tile_mnk, self.epilogue_tile_mn,
|
| 110 |
-
self.stages_c, self.stages_d, self.epi_tiles)
|
| 111 |
-
if node == "D":
|
| 112 |
-
continue
|
| 113 |
-
if isinstance(meta, TopoVisitorNode):
|
| 114 |
-
self.get_dag_smem_type(node)
|
| 115 |
-
else:
|
| 116 |
-
self.get_evt_smem_type(node)
|
| 117 |
-
|
| 118 |
-
thread_smem_size = self.smem_types[self.dag_ir.get_all_inputs("D")[0]][0]
|
| 119 |
-
# Get the Tensor Storage
|
| 120 |
-
tensors = []
|
| 121 |
-
if self.is_source_supported:
|
| 122 |
-
smem_C = DataTypeSize[self.element_c] * product(self.epilogue_tile_mn) * self.stages_c // 8
|
| 123 |
-
tensors.append((smem_C, 128))
|
| 124 |
-
else:
|
| 125 |
-
tensors.append((0, 1))
|
| 126 |
-
if self.reuse_smem_c:
|
| 127 |
-
tensors.append((0, 128))
|
| 128 |
-
else:
|
| 129 |
-
smem_D = DataTypeSize[self.element_d] * product(self.epilogue_tile_mn) * self.stages_d // 8
|
| 130 |
-
tensors.append((smem_D, 128))
|
| 131 |
-
tensors.append((thread_smem_size, 128))
|
| 132 |
-
|
| 133 |
-
tensor_smem_size = self.get_struct_size(tensors)
|
| 134 |
-
# Get pipeline storage size
|
| 135 |
-
# sizeof(uint64_t * stages_c * 2), alignment of uint64_t
|
| 136 |
-
# 2 is for FullBarrier and EmptyBarrier
|
| 137 |
-
pipeline_smem_size = (8 * self.stages_c * 2, 8)
|
| 138 |
-
|
| 139 |
-
# get SharedStorage size
|
| 140 |
-
smem_size = self.get_struct_size([tensor_smem_size, pipeline_smem_size])
|
| 141 |
-
return smem_size[0]
|
| 142 |
-
|
| 143 |
-
def sm90_epilogue_smem_size(self, tile_description):
|
| 144 |
-
"""
|
| 145 |
-
Compute the shared memory size of sm90 collective epilogue
|
| 146 |
-
"""
|
| 147 |
-
self.sm90_epilogue_tile(tile_description)
|
| 148 |
-
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
| 149 |
-
|
| 150 |
-
#
|
| 151 |
-
# Sm100 epilogue specific
|
| 152 |
-
#
|
| 153 |
-
|
| 154 |
-
def sm100_epilogue_tile(self, tile_description):
|
| 155 |
-
cta_tile = (tile_description.blackwell_threadblock_shape[0], tile_description.blackwell_threadblock_shape[1])
|
| 156 |
-
mma_tile = cta_tile
|
| 157 |
-
|
| 158 |
-
if tile_description.is_2sm:
|
| 159 |
-
cta_tile = (cta_tile[0] // 2, cta_tile[1])
|
| 160 |
-
|
| 161 |
-
if tile_description.is_2sm and mma_tile[0] == 128:
|
| 162 |
-
tmem_warps = (2, 2)
|
| 163 |
-
else:
|
| 164 |
-
tmem_warps = (4, 1)
|
| 165 |
-
|
| 166 |
-
if self.dag_ir.has_node("C"):
|
| 167 |
-
element_c = self.dag_ir.get_node_meta("C").element
|
| 168 |
-
element_c_size = DataTypeSize[element_c]
|
| 169 |
-
else:
|
| 170 |
-
element_c = None
|
| 171 |
-
element_c_size = 0
|
| 172 |
-
|
| 173 |
-
element_d = self.dag_ir.get_node_meta("D").element
|
| 174 |
-
|
| 175 |
-
DisableSource = element_c is None or not self.dag_ir.has_node("C") or self.dag_ir.get_node_meta("C").element == DataType.void
|
| 176 |
-
|
| 177 |
-
CtaM = cta_tile[0]
|
| 178 |
-
CtaN = cta_tile[1]
|
| 179 |
-
WarpM = tmem_warps[0]
|
| 180 |
-
WarpN = tmem_warps[1]
|
| 181 |
-
MaxBits = max(element_c_size, DataTypeSize[element_d])
|
| 182 |
-
DpFull = 32
|
| 183 |
-
M = min(CtaM, DpFull * WarpM)
|
| 184 |
-
|
| 185 |
-
if DisableSource:
|
| 186 |
-
# Epilogues w/o residual load are less sensitive to smem allocation
|
| 187 |
-
# Target a fixed amount of compute per epilogue iteration
|
| 188 |
-
if MaxBits == 4:
|
| 189 |
-
# Make epilogue tile larger to reduce the epilogue iterations.
|
| 190 |
-
# 64 is the experimental value. It will minimize epilogue iterations but keep the number of A/B buffers the same.
|
| 191 |
-
ComputeElts = 8192
|
| 192 |
-
Nperf = ComputeElts // M
|
| 193 |
-
else:
|
| 194 |
-
ComputeElts = 4096
|
| 195 |
-
Nperf = ComputeElts // M
|
| 196 |
-
else:
|
| 197 |
-
# Epilogues w/ residual load are more sensitive to smem allocation
|
| 198 |
-
# Target optimal smem distribution between epilogue+mainloop based on datatype+tilesize
|
| 199 |
-
if MaxBits == 32:
|
| 200 |
-
Nperf = 16 if CtaM > 64 and CtaN <= 128 else 32
|
| 201 |
-
elif MaxBits == 16:
|
| 202 |
-
Nperf = 32 if CtaN <= 128 else 64
|
| 203 |
-
else:
|
| 204 |
-
Nperf = 64
|
| 205 |
-
|
| 206 |
-
def is_m_major(layout):
|
| 207 |
-
return flatten(layout.stride[0]) == 1
|
| 208 |
-
|
| 209 |
-
if DisableSource or is_m_major(self.dag_ir.get_node_meta("C").tensor.layout):
|
| 210 |
-
N_min_C = 8 * WarpN
|
| 211 |
-
elif element_c_size == 6:
|
| 212 |
-
N_min_C = 128 * WarpN
|
| 213 |
-
else:
|
| 214 |
-
N_min_C = (128 // element_c_size) * WarpN
|
| 215 |
-
|
| 216 |
-
if is_m_major(self.dag_ir.get_node_meta("D").tensor.layout):
|
| 217 |
-
N_min_D = 8 * WarpN
|
| 218 |
-
elif DataTypeSize[element_d] == 6:
|
| 219 |
-
N_min_D = 128 * WarpN
|
| 220 |
-
else:
|
| 221 |
-
N_min_D = (128 // DataTypeSize[element_d]) * WarpN
|
| 222 |
-
|
| 223 |
-
N = min(CtaN, max(Nperf, N_min_C, N_min_D))
|
| 224 |
-
|
| 225 |
-
tile_m = M
|
| 226 |
-
tile_n_size = N // WarpN * WarpN
|
| 227 |
-
|
| 228 |
-
epilogue_tile_mn = (tile_m, tile_n_size)
|
| 229 |
-
epi_tiles = product(shape_div(tuple(tile_description.threadblock_shape)[:2], epilogue_tile_mn))
|
| 230 |
-
|
| 231 |
-
stages_d = min(epi_tiles, 2)
|
| 232 |
-
reuse_smem_c = (element_c_size > 8)
|
| 233 |
-
|
| 234 |
-
if reuse_smem_c:
|
| 235 |
-
stages_c = max(min(epi_tiles, 4), stages_d + 1)
|
| 236 |
-
else:
|
| 237 |
-
stages_c = min(epi_tiles, 4)
|
| 238 |
-
|
| 239 |
-
# Record the epilogue tile
|
| 240 |
-
self.cta_tile_mnk = tuple(tile_description.threadblock_shape)
|
| 241 |
-
self.epilogue_tile_mn = epilogue_tile_mn
|
| 242 |
-
self.epi_tiles = epi_tiles
|
| 243 |
-
self.stages_c = stages_c
|
| 244 |
-
self.stages_d = stages_d
|
| 245 |
-
self.reuse_smem_c = reuse_smem_c
|
| 246 |
-
self.element_c = element_c
|
| 247 |
-
self.element_d = element_d
|
| 248 |
-
self.is_source_supported = not DisableSource
|
| 249 |
-
|
| 250 |
-
def sm100_epilogue_smem_size(self, tile_description):
|
| 251 |
-
"""
|
| 252 |
-
Compute the shared memory size of sm100 collective epilogue
|
| 253 |
-
"""
|
| 254 |
-
self.sm100_epilogue_tile(tile_description)
|
| 255 |
-
return self.sm90_or_sm100_epilogue_smem_size(tile_description)
|
| 256 |
-
|
| 257 |
-
def __call__(self, tile_description):
|
| 258 |
-
return getattr(self, f"sm{self.cc}_epilogue_smem_size")(tile_description)
|
| 259 |
-
|
| 260 |
-
#
|
| 261 |
-
# Helper functions
|
| 262 |
-
#
|
| 263 |
-
|
| 264 |
-
@staticmethod
|
| 265 |
-
def get_visitor_size(members: list, ebo: bool):
|
| 266 |
-
"""
|
| 267 |
-
Get the size of struct in bytes
|
| 268 |
-
"""
|
| 269 |
-
offset = 0
|
| 270 |
-
max_alignment = 1
|
| 271 |
-
if len(members) > 0:
|
| 272 |
-
# Get alignment
|
| 273 |
-
for _, alignment in members:
|
| 274 |
-
max_alignment = max(max_alignment, alignment)
|
| 275 |
-
|
| 276 |
-
for type_size, _ in members:
|
| 277 |
-
if type_size != 0:
|
| 278 |
-
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
| 279 |
-
if type_size == 0 and not ebo:
|
| 280 |
-
offset += 1
|
| 281 |
-
else:
|
| 282 |
-
offset += type_size
|
| 283 |
-
offset = ((offset + max_alignment - 1) // max_alignment) * max_alignment
|
| 284 |
-
return (offset, max_alignment)
|
| 285 |
-
else:
|
| 286 |
-
# Struct size is at least 1
|
| 287 |
-
return (1, 1)
|
| 288 |
-
|
| 289 |
-
def get_struct_size(self, members: list):
|
| 290 |
-
"""
|
| 291 |
-
Get the size of struct in bytes
|
| 292 |
-
"""
|
| 293 |
-
return self.get_visitor_size(members, False)
|
| 294 |
-
|
| 295 |
-
def get_evt_smem_type(self, node):
|
| 296 |
-
# Sort the input nodes by edge weight
|
| 297 |
-
input_types = [self.smem_types[child] for child in self.dag_ir.get_all_inputs(node)]
|
| 298 |
-
input_types.append(self.smem_types[node])
|
| 299 |
-
if len(input_types) > 1:
|
| 300 |
-
ebo = len(input_types) > 4
|
| 301 |
-
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
| 302 |
-
|
| 303 |
-
def get_dag_smem_type(self, node):
|
| 304 |
-
meta = self.dag_ir.get_node_meta(node)
|
| 305 |
-
subgraph = meta.subgraph
|
| 306 |
-
subgraph_nodes = subgraph.nodes_topological_order()
|
| 307 |
-
# Visit the unvisited nodes in subgraph
|
| 308 |
-
for n in subgraph_nodes:
|
| 309 |
-
m = subgraph.get_node_meta(n)
|
| 310 |
-
if m.disabled:
|
| 311 |
-
continue
|
| 312 |
-
else:
|
| 313 |
-
self.smem_types[n] = m.underlying_impl.get_smem_size(
|
| 314 |
-
self.cta_tile_mnk, self.epilogue_tile_mn,
|
| 315 |
-
self.stages_c, self.stages_d, self.epi_tiles)
|
| 316 |
-
input_types = [self.smem_types[child] for child in subgraph_nodes[:-1]]
|
| 317 |
-
if len(input_types) > 0:
|
| 318 |
-
ebo = len(input_types) > 4
|
| 319 |
-
self.smem_types[node] = self.get_visitor_size(input_types, ebo)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/evt/passes/util.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utilities for passes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
# Map from the CC of the kernel to the EVT implementation that the CC targets
|
| 38 |
-
cc_map = {
|
| 39 |
-
80: 80,
|
| 40 |
-
86: 80,
|
| 41 |
-
89: 80,
|
| 42 |
-
90: 90,
|
| 43 |
-
100: 100,
|
| 44 |
-
101: 100,
|
| 45 |
-
103: 100,
|
| 46 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/frontend.py
DELETED
|
@@ -1,109 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
from __future__ import annotations
|
| 33 |
-
|
| 34 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 35 |
-
cuda = lazy_import("cuda.cuda")
|
| 36 |
-
import numpy as np
|
| 37 |
-
|
| 38 |
-
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
|
| 39 |
-
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class NumpyFrontend:
|
| 43 |
-
"""
|
| 44 |
-
Frontend node for numpy
|
| 45 |
-
"""
|
| 46 |
-
|
| 47 |
-
@staticmethod
|
| 48 |
-
def argument(np_tensor: "np.ndarray", is_output: "bool") -> cuda.CUdeviceptr:
|
| 49 |
-
"""Convert the input numpy tensor to CUDA device pointer
|
| 50 |
-
|
| 51 |
-
:param np_tensor: input numpy nd array
|
| 52 |
-
:param is_output: whether the tensor is output
|
| 53 |
-
|
| 54 |
-
:return: CUDA device pointer
|
| 55 |
-
"""
|
| 56 |
-
# copy the data to device
|
| 57 |
-
if is_output:
|
| 58 |
-
return device_mem_alloc(np_tensor.size * np_tensor.itemsize)
|
| 59 |
-
else:
|
| 60 |
-
return todevice(np_tensor)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
class TorchFrontend:
|
| 64 |
-
"""
|
| 65 |
-
Frontend node for torch
|
| 66 |
-
"""
|
| 67 |
-
|
| 68 |
-
@staticmethod
|
| 69 |
-
def argument(torch_tensor: "torch.Tensor") -> cuda.CUdeviceptr:
|
| 70 |
-
"""Convert the input torch tensor to CUDA device pointer
|
| 71 |
-
|
| 72 |
-
:param torch_tensor: input torch tensor
|
| 73 |
-
:param is_output: whether the tensor is output
|
| 74 |
-
|
| 75 |
-
:return: CUDA device pointer
|
| 76 |
-
"""
|
| 77 |
-
|
| 78 |
-
# check the device of torch_tensor
|
| 79 |
-
if not torch_tensor.is_cuda:
|
| 80 |
-
torch_tensor = torch_tensor.to("cuda")
|
| 81 |
-
|
| 82 |
-
return cuda.CUdeviceptr(torch_tensor.data_ptr())
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
class CupyFrontend:
|
| 86 |
-
"""
|
| 87 |
-
Frontend node for cupy
|
| 88 |
-
"""
|
| 89 |
-
|
| 90 |
-
@staticmethod
|
| 91 |
-
def argument(cupy_ndarray: "cp.ndarray"):
|
| 92 |
-
return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr))
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
class TensorFrontend:
|
| 96 |
-
"""
|
| 97 |
-
Universal Frontend for client-provide tensors
|
| 98 |
-
"""
|
| 99 |
-
|
| 100 |
-
@staticmethod
|
| 101 |
-
def argument(tensor, is_output=False):
|
| 102 |
-
if is_numpy_tensor(tensor):
|
| 103 |
-
return NumpyFrontend.argument(tensor, is_output)
|
| 104 |
-
elif is_torch_tensor(tensor):
|
| 105 |
-
return TorchFrontend.argument(tensor)
|
| 106 |
-
elif is_cupy_tensor(tensor):
|
| 107 |
-
return CupyFrontend.argument(tensor)
|
| 108 |
-
else:
|
| 109 |
-
raise NotImplementedError("Unknown Tensor Type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/gemm_operation.py
DELETED
|
@@ -1,2145 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
from __future__ import annotations
|
| 33 |
-
|
| 34 |
-
import copy
|
| 35 |
-
import ctypes
|
| 36 |
-
import enum
|
| 37 |
-
|
| 38 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 39 |
-
cuda = lazy_import("cuda.cuda")
|
| 40 |
-
cudart = lazy_import("cuda.cudart")
|
| 41 |
-
from cutlass_library import SubstituteTemplate
|
| 42 |
-
import numpy as np
|
| 43 |
-
|
| 44 |
-
from cutlass_library import (
|
| 45 |
-
ComplexTransformTag,
|
| 46 |
-
DataType,
|
| 47 |
-
DataTypeNames,
|
| 48 |
-
DataTypeSize,
|
| 49 |
-
DataTypeTag,
|
| 50 |
-
EpilogueScheduleSuffixes,
|
| 51 |
-
EpilogueScheduleTag,
|
| 52 |
-
EpilogueScheduleType,
|
| 53 |
-
GemmKind,
|
| 54 |
-
GemmKindNames,
|
| 55 |
-
GemmUniversalMode,
|
| 56 |
-
KernelScheduleSuffixes,
|
| 57 |
-
KernelScheduleTag,
|
| 58 |
-
KernelScheduleType,
|
| 59 |
-
LayoutTag,
|
| 60 |
-
LayoutType,
|
| 61 |
-
MathOperation,
|
| 62 |
-
MathOperationTag,
|
| 63 |
-
OpcodeClass,
|
| 64 |
-
OpcodeClassNames,
|
| 65 |
-
OpcodeClassTag,
|
| 66 |
-
OperationKind,
|
| 67 |
-
ShortComplexLayoutNames,
|
| 68 |
-
ShortDataTypeNames,
|
| 69 |
-
ShortLayoutTypeNames,
|
| 70 |
-
SwizzlingFunctor,
|
| 71 |
-
SwizzlingFunctorTag,
|
| 72 |
-
TileSchedulerSuffixes,
|
| 73 |
-
TileSchedulerTag,
|
| 74 |
-
TileSchedulerType,
|
| 75 |
-
get_complex_from_real
|
| 76 |
-
)
|
| 77 |
-
from cutlass_cppgen.backend.arguments import ArgumentBase
|
| 78 |
-
from cutlass_cppgen.backend.c_types import (
|
| 79 |
-
GemmCoord_,
|
| 80 |
-
GemmCoordBatched_,
|
| 81 |
-
GenericMainloopArguments3x_,
|
| 82 |
-
StrideBatched_,
|
| 83 |
-
dim3_,
|
| 84 |
-
get_gemm_arguments,
|
| 85 |
-
get_gemm_arguments_3x,
|
| 86 |
-
get_gemm_arguments_streamk,
|
| 87 |
-
get_gemm_grouped_arguments,
|
| 88 |
-
get_mainloop_arguments_3x,
|
| 89 |
-
get_tile_scheduler_arguments_3x,
|
| 90 |
-
)
|
| 91 |
-
from cutlass_cppgen.backend.library import (
|
| 92 |
-
ApiVersion,
|
| 93 |
-
EmissionType,
|
| 94 |
-
SchedulerMode,
|
| 95 |
-
SchedulerModeTag,
|
| 96 |
-
TensorDescription,
|
| 97 |
-
TileDescription,
|
| 98 |
-
api_version,
|
| 99 |
-
)
|
| 100 |
-
from cutlass_cppgen.backend.memory_manager import device_mem_alloc, todevice
|
| 101 |
-
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
| 102 |
-
from cutlass_cppgen.backend.type_hint import GemmOperation, Tensor
|
| 103 |
-
from cutlass_cppgen.backend.utils.device import device_sm_count
|
| 104 |
-
from cutlass_cppgen.shape import GemmCoord, MatrixCoord
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
################################################################################
|
| 108 |
-
#
|
| 109 |
-
# Data structure modeling a GEMM operation
|
| 110 |
-
#
|
| 111 |
-
################################################################################
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def leading_dimension(layout: LayoutType, shape: MatrixCoord) -> int:
|
| 115 |
-
"""
|
| 116 |
-
Returns the leading dimenson of a tensor with layout ``layout`` and shape ``shape``.
|
| 117 |
-
|
| 118 |
-
:param layout: layout of the tensor
|
| 119 |
-
:type layout: cutlass_cppgen.shape.LayoutType
|
| 120 |
-
:param shape: shape of the tensor
|
| 121 |
-
:type shape: cutlass_cppgen.shape.MatrixCoord
|
| 122 |
-
|
| 123 |
-
:return: leading dimension of the tensor
|
| 124 |
-
:rtype: int
|
| 125 |
-
"""
|
| 126 |
-
if layout == LayoutType.RowMajor:
|
| 127 |
-
return shape.column
|
| 128 |
-
elif layout == LayoutType.ColumnMajor:
|
| 129 |
-
return shape.row
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def transpose_layout(layout: LayoutType) -> LayoutType:
|
| 133 |
-
if layout == LayoutType.ColumnMajor:
|
| 134 |
-
return LayoutType.RowMajor
|
| 135 |
-
elif layout == LayoutType.RowMajor:
|
| 136 |
-
return LayoutType.ColumnMajor
|
| 137 |
-
else:
|
| 138 |
-
raise ValueError(f"Unsupported Layout {layout}")
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
class GemmArguments2x(ArgumentBase):
|
| 142 |
-
"""
|
| 143 |
-
Argument wrapper for GEMM in CUTLASS 2. It encodes problem information and
|
| 144 |
-
user-provide tensors into the kernel's argument
|
| 145 |
-
|
| 146 |
-
:param operation: the GEMM operation to take the argument
|
| 147 |
-
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 148 |
-
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 149 |
-
|
| 150 |
-
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 151 |
-
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 152 |
-
|
| 153 |
-
:param A: tensor A
|
| 154 |
-
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 155 |
-
|
| 156 |
-
:param B: tensor B
|
| 157 |
-
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 158 |
-
|
| 159 |
-
:param C: tensor C
|
| 160 |
-
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 161 |
-
|
| 162 |
-
:param D: tensor D
|
| 163 |
-
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 164 |
-
|
| 165 |
-
:param gemm_mode: GEMM mode
|
| 166 |
-
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
| 167 |
-
|
| 168 |
-
:param output_op: output operator, optional
|
| 169 |
-
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 170 |
-
|
| 171 |
-
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 172 |
-
:type stream: :class:`cuda.cuda.CUstream`
|
| 173 |
-
"""
|
| 174 |
-
|
| 175 |
-
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 176 |
-
self.operation = operation
|
| 177 |
-
|
| 178 |
-
self.layout_A = operation.A.layout
|
| 179 |
-
self.layout_B = operation.B.layout
|
| 180 |
-
self.layout_C = operation.C.layout
|
| 181 |
-
|
| 182 |
-
self.element_A = operation.A.element
|
| 183 |
-
self.element_B = operation.B.element
|
| 184 |
-
self.element_C = operation.C.element
|
| 185 |
-
|
| 186 |
-
if operation.C.layout in [LayoutType.RowMajorInterleaved32, LayoutType.ColumnMajorInterleaved32]:
|
| 187 |
-
raise Exception("Interleaved layout not currently supported")
|
| 188 |
-
|
| 189 |
-
if hasattr(self.operation.epilogue_functor, "visitor") and operation.arch not in [90, 100, 101, 103]:
|
| 190 |
-
super().__init__(A, B, None, None, **kwargs)
|
| 191 |
-
else:
|
| 192 |
-
super().__init__(A, B, C, D, **kwargs)
|
| 193 |
-
|
| 194 |
-
if operation.switched:
|
| 195 |
-
self.problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
|
| 196 |
-
self.ptr_A, self.ptr_B = self.ptr_B, self.ptr_A
|
| 197 |
-
else:
|
| 198 |
-
self.problem_size = problem_size
|
| 199 |
-
# If the number of elements in C = problem_size.n, C is treated as the bias
|
| 200 |
-
if hasattr(self, "tensor_c_numel"):
|
| 201 |
-
if self.tensor_c_numel == self.problem_size.n and self.problem_size.m != 1:
|
| 202 |
-
self.bias = True
|
| 203 |
-
|
| 204 |
-
self.lda = leading_dimension(self.layout_A, self.problem_size.mk)
|
| 205 |
-
self.ldb = leading_dimension(self.layout_B, self.problem_size.kn)
|
| 206 |
-
self.ldc = leading_dimension(self.layout_C, self.problem_size.mn)
|
| 207 |
-
self.ldd = self.ldc
|
| 208 |
-
|
| 209 |
-
if self.bias:
|
| 210 |
-
self.ldc = 0
|
| 211 |
-
|
| 212 |
-
if "output_op" in kwargs.keys() and gemm_mode != GemmUniversalMode.GemmSplitKParallel:
|
| 213 |
-
self.output_op = kwargs["output_op"]
|
| 214 |
-
else:
|
| 215 |
-
if self.operation.epilogue_functor.element_epilogue in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
|
| 216 |
-
dtype = int
|
| 217 |
-
else:
|
| 218 |
-
dtype = float
|
| 219 |
-
self.output_op = self.operation.epilogue_type(dtype(1.0), dtype(0.0))
|
| 220 |
-
|
| 221 |
-
self.gemm_mode = gemm_mode
|
| 222 |
-
if gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
|
| 223 |
-
if "split_k_slices" in kwargs.keys():
|
| 224 |
-
self.batch_count = kwargs["split_k_slices"]
|
| 225 |
-
else:
|
| 226 |
-
self.batch_count = 1
|
| 227 |
-
self.split_k_slices = self.batch_count
|
| 228 |
-
|
| 229 |
-
if gemm_mode in [GemmUniversalMode.Batched, GemmUniversalMode.Array]:
|
| 230 |
-
if "batch" in kwargs.keys():
|
| 231 |
-
self.batch_count = kwargs["batch"]
|
| 232 |
-
else:
|
| 233 |
-
self.batch_count = 1
|
| 234 |
-
|
| 235 |
-
if "batch_strides" in kwargs:
|
| 236 |
-
self.batched_stride_A = kwargs["batch_strides"]["A"]
|
| 237 |
-
self.batched_stride_B = kwargs["batch_strides"]["B"]
|
| 238 |
-
self.batched_stride_C = kwargs["batch_strides"]["C"]
|
| 239 |
-
self.batched_stride_D = kwargs["batch_strides"]["D"]
|
| 240 |
-
else:
|
| 241 |
-
self.batched_stride_A = self.problem_size.m * self.problem_size.k
|
| 242 |
-
self.batched_stride_B = self.problem_size.n * self.problem_size.k
|
| 243 |
-
self.batched_stride_C = self.problem_size.m * self.problem_size.n
|
| 244 |
-
self.batched_stride_D = self.problem_size.m * self.problem_size.n
|
| 245 |
-
|
| 246 |
-
if self.bias:
|
| 247 |
-
self.batched_stride_C = self.problem_size.n
|
| 248 |
-
|
| 249 |
-
if gemm_mode == GemmUniversalMode.Array:
|
| 250 |
-
self.ptr_A_array = []
|
| 251 |
-
self.ptr_B_array = []
|
| 252 |
-
self.ptr_C_array = []
|
| 253 |
-
self.ptr_D_array = []
|
| 254 |
-
|
| 255 |
-
ptr_A_addr = int(self.ptr_A)
|
| 256 |
-
ptr_B_addr = int(self.ptr_B)
|
| 257 |
-
ptr_C_addr = int(self.ptr_C)
|
| 258 |
-
ptr_D_addr = int(self.ptr_D)
|
| 259 |
-
|
| 260 |
-
stride_A = self.batched_stride_A * DataTypeSize[self.element_A] // 8
|
| 261 |
-
stride_B = self.batched_stride_B * DataTypeSize[self.element_B] // 8
|
| 262 |
-
stride_C = self.batched_stride_C * DataTypeSize[self.element_C] // 8
|
| 263 |
-
stride_D = self.batched_stride_D * DataTypeSize[self.element_C] // 8
|
| 264 |
-
for _ in range(self.batch_count):
|
| 265 |
-
self.ptr_A_array.append(ptr_A_addr)
|
| 266 |
-
self.ptr_B_array.append(ptr_B_addr)
|
| 267 |
-
self.ptr_C_array.append(ptr_C_addr)
|
| 268 |
-
self.ptr_D_array.append(ptr_D_addr)
|
| 269 |
-
|
| 270 |
-
ptr_A_addr += stride_A
|
| 271 |
-
ptr_B_addr += stride_B
|
| 272 |
-
ptr_C_addr += stride_C
|
| 273 |
-
ptr_D_addr += stride_D
|
| 274 |
-
|
| 275 |
-
self.ptr_A_array_buffer = todevice(self.ptr_A_array, dtype=np.int64)
|
| 276 |
-
self.ptr_B_array_buffer = todevice(self.ptr_B_array, dtype=np.int64)
|
| 277 |
-
self.ptr_C_array_buffer = todevice(self.ptr_C_array, dtype=np.int64)
|
| 278 |
-
self.ptr_D_array_buffer = todevice(self.ptr_D_array, dtype=np.int64)
|
| 279 |
-
|
| 280 |
-
if isinstance(self.operation, GemmOperationUniversal):
|
| 281 |
-
self.initialize()
|
| 282 |
-
|
| 283 |
-
def get_arguments(self):
|
| 284 |
-
problem_size_ = self.problem_size.ctype
|
| 285 |
-
grid_tiled_shape_ = GemmCoord(
|
| 286 |
-
self.grid_tiled_shape.x,
|
| 287 |
-
self.grid_tiled_shape.y,
|
| 288 |
-
self.grid_tiled_shape.z ).ctype
|
| 289 |
-
|
| 290 |
-
if self.gemm_mode == GemmUniversalMode.Array:
|
| 291 |
-
arguments = self.operation.argument_type(
|
| 292 |
-
# Arguments from UniversalArgumentsBase
|
| 293 |
-
self.gemm_mode,
|
| 294 |
-
problem_size_,
|
| 295 |
-
self.batch_count,
|
| 296 |
-
0,
|
| 297 |
-
# Remaining arguments
|
| 298 |
-
self.output_op,
|
| 299 |
-
int(self.ptr_A_array_buffer.ptr),
|
| 300 |
-
int(self.ptr_B_array_buffer.ptr),
|
| 301 |
-
int(self.ptr_C_array_buffer.ptr),
|
| 302 |
-
int(self.ptr_D_array_buffer.ptr),
|
| 303 |
-
0, 0, 0,
|
| 304 |
-
self.lda, self.ldb, self.ldc, self.ldd,
|
| 305 |
-
self.lda, self.ldb, self.ldc, self.ldd,
|
| 306 |
-
0, 0, 0
|
| 307 |
-
)
|
| 308 |
-
else:
|
| 309 |
-
arguments = self.operation.argument_type(
|
| 310 |
-
# Arguments from UniversalArgumentsBase
|
| 311 |
-
self.gemm_mode, problem_size_, self.batch_count, self.batched_stride_D,
|
| 312 |
-
# Remaining arguments
|
| 313 |
-
self.output_op,
|
| 314 |
-
int(self.ptr_A),
|
| 315 |
-
int(self.ptr_B),
|
| 316 |
-
int(self.ptr_C),
|
| 317 |
-
int(self.ptr_D),
|
| 318 |
-
self.batched_stride_A,
|
| 319 |
-
self.batched_stride_B,
|
| 320 |
-
self.batched_stride_C,
|
| 321 |
-
self.lda, self.ldb, self.ldc, self.ldd,
|
| 322 |
-
self.lda, self.ldb, self.ldc, self.ldd,
|
| 323 |
-
0, 0, 0
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
self.arguments = arguments, grid_tiled_shape_, self.gemm_k_size
|
| 327 |
-
|
| 328 |
-
def initialize(self):
|
| 329 |
-
launch_config = self.operation.rt_module.plan(self)
|
| 330 |
-
|
| 331 |
-
# Get the host and device workspace
|
| 332 |
-
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
| 333 |
-
|
| 334 |
-
if device_workspace_size > 0:
|
| 335 |
-
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 336 |
-
workspace_ptr = self.workspace_buffer.ptr
|
| 337 |
-
err, = cuda.cuMemsetD32(
|
| 338 |
-
workspace_ptr, 0, device_workspace_size // 4)
|
| 339 |
-
else:
|
| 340 |
-
workspace_ptr = None
|
| 341 |
-
|
| 342 |
-
device_workspace = 0
|
| 343 |
-
if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 344 |
-
# In GEMM splik-K parallel, the D pointer is redirected to the workspace
|
| 345 |
-
self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
|
| 346 |
-
elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
|
| 347 |
-
device_workspace = workspace_ptr
|
| 348 |
-
|
| 349 |
-
self.get_arguments()
|
| 350 |
-
|
| 351 |
-
arguments, grid_tiled_shape, gemm_k_size = self.arguments
|
| 352 |
-
res_arg = self.operation.rt_module.get_args(
|
| 353 |
-
ctypes.byref(arguments), ctypes.c_void_p(int(device_workspace)))
|
| 354 |
-
host_workspace = bytearray(res_arg.contents)
|
| 355 |
-
|
| 356 |
-
device_workspace = None
|
| 357 |
-
|
| 358 |
-
self.host_workspace = host_workspace
|
| 359 |
-
self.device_workspace = device_workspace
|
| 360 |
-
self.launch_config = launch_config
|
| 361 |
-
|
| 362 |
-
def sync(self, stream_sync=True):
|
| 363 |
-
super().sync(stream_sync)
|
| 364 |
-
if hasattr(self.output_op, "sync"):
|
| 365 |
-
self.output_op.sync()
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
class GemmArguments2xStreamK(GemmArguments2x):
|
| 369 |
-
"""
|
| 370 |
-
Argument wrapper for stream-K GEMMs in CUTLASS 2. It encodes problem information and
|
| 371 |
-
user-provide tensors into the kernel's argument
|
| 372 |
-
|
| 373 |
-
:param operation: the GEMM operation to take the argument
|
| 374 |
-
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 375 |
-
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 376 |
-
|
| 377 |
-
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 378 |
-
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 379 |
-
|
| 380 |
-
:param A: tensor A
|
| 381 |
-
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 382 |
-
|
| 383 |
-
:param B: tensor B
|
| 384 |
-
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 385 |
-
|
| 386 |
-
:param C: tensor C
|
| 387 |
-
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 388 |
-
|
| 389 |
-
:param D: tensor D
|
| 390 |
-
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 391 |
-
|
| 392 |
-
:param gemm_mode: GEMM mode
|
| 393 |
-
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
| 394 |
-
|
| 395 |
-
:param output_op: output operator, optional
|
| 396 |
-
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 397 |
-
"""
|
| 398 |
-
|
| 399 |
-
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 400 |
-
if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
|
| 401 |
-
raise Exception(f"Unsupported GEMM mode {gemm_mode}.")
|
| 402 |
-
|
| 403 |
-
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
| 404 |
-
|
| 405 |
-
def get_arguments(self):
|
| 406 |
-
batch_stride_A = self.problem_size.m * self.problem_size.k
|
| 407 |
-
batch_stride_B = self.problem_size.k * self.problem_size.n
|
| 408 |
-
batch_stride_C = self.problem_size.m * self.problem_size.n
|
| 409 |
-
batch_stride_D = self.problem_size.m * self.problem_size.n
|
| 410 |
-
|
| 411 |
-
arguments = self.operation.argument_type(
|
| 412 |
-
self.gemm_mode,
|
| 413 |
-
GemmCoord_(self.problem_size.m, self.problem_size.n, self.problem_size.k),
|
| 414 |
-
self.batch_count,
|
| 415 |
-
self.output_op,
|
| 416 |
-
int(self.ptr_A),
|
| 417 |
-
int(self.ptr_B),
|
| 418 |
-
int(self.ptr_C),
|
| 419 |
-
int(self.ptr_D),
|
| 420 |
-
batch_stride_A,
|
| 421 |
-
batch_stride_B,
|
| 422 |
-
batch_stride_C,
|
| 423 |
-
batch_stride_D,
|
| 424 |
-
self.lda, self.ldb, self.ldc, self.ldd, # strides
|
| 425 |
-
self.lda, self.ldb, self.ldc, self.ldd,
|
| 426 |
-
-1, # avail_sms
|
| 427 |
-
)
|
| 428 |
-
return arguments
|
| 429 |
-
|
| 430 |
-
def initialize(self):
|
| 431 |
-
# Get the host and device workspace
|
| 432 |
-
device_workspace_size = self.operation.rt_module.get_device_workspace_size(
|
| 433 |
-
self,
|
| 434 |
-
device_sm_count(),
|
| 435 |
-
self.operation.rt_module.occupancy
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
if device_workspace_size > 0:
|
| 439 |
-
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 440 |
-
workspace_ptr = self.workspace_buffer.ptr
|
| 441 |
-
err, = cuda.cuMemsetD32(
|
| 442 |
-
workspace_ptr, 0, device_workspace_size // 4)
|
| 443 |
-
else:
|
| 444 |
-
workspace_ptr = None
|
| 445 |
-
|
| 446 |
-
device_workspace = 0
|
| 447 |
-
if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 448 |
-
# In GEMM splik-K parallel, the D pointer is redirected to the workspace
|
| 449 |
-
self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
|
| 450 |
-
elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
|
| 451 |
-
device_workspace = workspace_ptr
|
| 452 |
-
|
| 453 |
-
arguments = self.get_arguments()
|
| 454 |
-
|
| 455 |
-
res_arg = self.operation.rt_module.get_args(
|
| 456 |
-
ctypes.byref(arguments),
|
| 457 |
-
ctypes.c_void_p(int(device_workspace)),
|
| 458 |
-
device_sm_count(),
|
| 459 |
-
self.operation.rt_module.occupancy
|
| 460 |
-
)
|
| 461 |
-
host_workspace = bytearray(res_arg.contents)
|
| 462 |
-
|
| 463 |
-
grid = self.operation.rt_module.get_grid_shape(
|
| 464 |
-
ctypes.byref(arguments),
|
| 465 |
-
device_sm_count(),
|
| 466 |
-
self.operation.rt_module.occupancy
|
| 467 |
-
)
|
| 468 |
-
|
| 469 |
-
device_workspace = None
|
| 470 |
-
|
| 471 |
-
self.host_workspace = host_workspace
|
| 472 |
-
self.device_workspace = device_workspace
|
| 473 |
-
self.launch_config = LaunchConfiguration(
|
| 474 |
-
[grid.m, grid.n, grid.k],
|
| 475 |
-
[self.operation.rt_module.threads, 1, 1],
|
| 476 |
-
self.operation.rt_module.shared_memory_capacity
|
| 477 |
-
)
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
class GemmArguments3x(GemmArguments2x):
|
| 481 |
-
"""
|
| 482 |
-
Argument wrapper for GEMM in CUTLASS 3. It encodes problem information and
|
| 483 |
-
user-provide tensors into the kernel's argument
|
| 484 |
-
|
| 485 |
-
:param operation: the GEMM operation to take the argument
|
| 486 |
-
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 487 |
-
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 488 |
-
|
| 489 |
-
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 490 |
-
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 491 |
-
|
| 492 |
-
:param A: tensor A
|
| 493 |
-
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 494 |
-
|
| 495 |
-
:param B: tensor B
|
| 496 |
-
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 497 |
-
|
| 498 |
-
:param C: tensor C
|
| 499 |
-
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 500 |
-
|
| 501 |
-
:param D: tensor D
|
| 502 |
-
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 503 |
-
|
| 504 |
-
:param gemm_mode: GEMM mode
|
| 505 |
-
:type gemm_mode: GemmUniversalMode
|
| 506 |
-
|
| 507 |
-
:param output_op: output operator, optional
|
| 508 |
-
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 509 |
-
"""
|
| 510 |
-
|
| 511 |
-
def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 512 |
-
if gemm_mode not in [GemmUniversalMode.Gemm, GemmUniversalMode.Batched]:
|
| 513 |
-
raise Exception(f"Unsupported GEMM mode {gemm_mode}.")
|
| 514 |
-
|
| 515 |
-
super().__init__(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
| 516 |
-
|
| 517 |
-
def get_arguments(self):
|
| 518 |
-
mainloop_args = get_mainloop_arguments_3x(
|
| 519 |
-
self.operation.tile_description.kernel_schedule,
|
| 520 |
-
self.operation.A.element,
|
| 521 |
-
self.operation.B.element,
|
| 522 |
-
self.operation.A.alignment,
|
| 523 |
-
self.operation.B.alignment
|
| 524 |
-
)
|
| 525 |
-
scheduler_args = get_tile_scheduler_arguments_3x(self.operation.tile_description.tile_scheduler)
|
| 526 |
-
uses_default_epilogue = self.operation.rt_module.uses_default_epilogue()
|
| 527 |
-
argument_type, epilogue_args, epilogue_type, hw_info = get_gemm_arguments_3x(
|
| 528 |
-
mainloop_args, self.operation.epilogue_functor, scheduler_args, uses_default_epilogue)
|
| 529 |
-
|
| 530 |
-
problem_size_ = GemmCoordBatched_(self.problem_size, self.batch_count)
|
| 531 |
-
|
| 532 |
-
if self.batch_count > 1:
|
| 533 |
-
bsA = self.batched_stride_A
|
| 534 |
-
bsB = self.batched_stride_B
|
| 535 |
-
bsC = self.batched_stride_C
|
| 536 |
-
bsD = self.batched_stride_D
|
| 537 |
-
else:
|
| 538 |
-
bsA = 0
|
| 539 |
-
bsB = 0
|
| 540 |
-
bsC = 0
|
| 541 |
-
bsD = 0
|
| 542 |
-
stride_A = StrideBatched_(self.lda, bsA)
|
| 543 |
-
stride_B = StrideBatched_(self.ldb, bsB)
|
| 544 |
-
stride_C = StrideBatched_(self.ldc, bsC)
|
| 545 |
-
stride_D = StrideBatched_(self.ldd, bsD)
|
| 546 |
-
|
| 547 |
-
# Superset of potential mainloop arguments
|
| 548 |
-
generic_args = GenericMainloopArguments3x_(
|
| 549 |
-
int(self.ptr_A),
|
| 550 |
-
stride_A,
|
| 551 |
-
int(self.ptr_B),
|
| 552 |
-
stride_B,
|
| 553 |
-
4 # mma_promotion_interval
|
| 554 |
-
)
|
| 555 |
-
|
| 556 |
-
# Set of mainloop arguments needed for this kernel
|
| 557 |
-
mainloop = mainloop_args.from_generic_mainloop_args(generic_args)
|
| 558 |
-
|
| 559 |
-
if not uses_default_epilogue and hasattr(self.output_op, "to_evt_params"):
|
| 560 |
-
self.output_op = self.output_op.to_evt_params()
|
| 561 |
-
|
| 562 |
-
epilogue = epilogue_args(
|
| 563 |
-
self.output_op,
|
| 564 |
-
int(self.ptr_C),
|
| 565 |
-
stride_C,
|
| 566 |
-
int(self.ptr_D),
|
| 567 |
-
stride_D,
|
| 568 |
-
)
|
| 569 |
-
|
| 570 |
-
# Set hardware info
|
| 571 |
-
hw_info_ = hw_info(
|
| 572 |
-
0, device_sm_count(), 0,
|
| 573 |
-
dim3_(0,0,0),
|
| 574 |
-
dim3_(0,0,0),
|
| 575 |
-
)
|
| 576 |
-
|
| 577 |
-
self.arguments = argument_type(
|
| 578 |
-
int(self.gemm_mode),
|
| 579 |
-
problem_size_,
|
| 580 |
-
mainloop,
|
| 581 |
-
epilogue,
|
| 582 |
-
hw_info_,
|
| 583 |
-
scheduler_args
|
| 584 |
-
)
|
| 585 |
-
return self.arguments
|
| 586 |
-
|
| 587 |
-
def initialize(self):
|
| 588 |
-
# Get the host and evice workspace
|
| 589 |
-
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
| 590 |
-
|
| 591 |
-
if device_workspace_size > 0:
|
| 592 |
-
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 593 |
-
workspace_ptr = self.workspace_buffer.ptr
|
| 594 |
-
err, = cuda.cuMemsetD32(
|
| 595 |
-
workspace_ptr, 0, device_workspace_size // 4)
|
| 596 |
-
else:
|
| 597 |
-
workspace_ptr = None
|
| 598 |
-
|
| 599 |
-
device_workspace = 0
|
| 600 |
-
if workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 601 |
-
# In GEMM splik-K parallel, the D pointer is redirected to the workspace
|
| 602 |
-
self.ptr_D = cuda.CUdeviceptr(workspace_ptr)
|
| 603 |
-
elif workspace_ptr is not None and self.gemm_mode == GemmUniversalMode.Gemm:
|
| 604 |
-
device_workspace = workspace_ptr
|
| 605 |
-
|
| 606 |
-
self.get_arguments()
|
| 607 |
-
res_arg = self.operation.rt_module.get_args(
|
| 608 |
-
ctypes.byref(self.arguments),
|
| 609 |
-
ctypes.c_void_p(int(device_workspace)),
|
| 610 |
-
)
|
| 611 |
-
host_workspace = bytearray(res_arg.contents)
|
| 612 |
-
|
| 613 |
-
grid = self.operation.rt_module.get_grid_shape(
|
| 614 |
-
ctypes.byref(self.arguments),
|
| 615 |
-
ctypes.c_void_p(int(device_workspace)),
|
| 616 |
-
)
|
| 617 |
-
block = self.operation.rt_module.get_block_shape()
|
| 618 |
-
|
| 619 |
-
device_workspace = None
|
| 620 |
-
|
| 621 |
-
self.host_workspace = host_workspace
|
| 622 |
-
self.device_workspace = device_workspace
|
| 623 |
-
self.launch_config = LaunchConfiguration(
|
| 624 |
-
[grid.x, grid.y, grid.z],
|
| 625 |
-
[block.x, block.y, block.z],
|
| 626 |
-
self.operation.rt_module.shared_memory_capacity,
|
| 627 |
-
)
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
|
| 631 |
-
"""
|
| 632 |
-
Argument wrapper for GEMM in CUTLASS 2 or 3. It returns either 2x arguments
|
| 633 |
-
or 3x arguments depending on the `arch` field specified in `operation`.
|
| 634 |
-
|
| 635 |
-
:param operation: the GEMM operation to take the argument
|
| 636 |
-
:type operation: :class:`cutlass_cppgen.backend.GemmOperationUniversal` |
|
| 637 |
-
:class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 638 |
-
|
| 639 |
-
:param problem_size: GEMM problem size gemm(M, N, K)
|
| 640 |
-
:type operation: :class:`cutlass_cppgen.shape.GemmCoord`
|
| 641 |
-
|
| 642 |
-
:param A: tensor A
|
| 643 |
-
:type A: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 644 |
-
|
| 645 |
-
:param B: tensor B
|
| 646 |
-
:type B: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 647 |
-
|
| 648 |
-
:param C: tensor C
|
| 649 |
-
:type C: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 650 |
-
|
| 651 |
-
:param D: tensor D
|
| 652 |
-
:type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray
|
| 653 |
-
|
| 654 |
-
:param gemm_mode: GEMM mode
|
| 655 |
-
:type gemm_mode: :class:`cutlass_library.GemmUniversalMode`
|
| 656 |
-
|
| 657 |
-
:param output_op: output operator, optional
|
| 658 |
-
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 659 |
-
"""
|
| 660 |
-
if operation.swizzling_functor == SwizzlingFunctor.StreamK:
|
| 661 |
-
if operation.api == ApiVersion.v3x:
|
| 662 |
-
raise Exception("Stream K is currently only supported in CUTLASS 2.x")
|
| 663 |
-
ArgClass = GemmArguments2xStreamK
|
| 664 |
-
else:
|
| 665 |
-
ArgClass = GemmArguments3x if operation.api == ApiVersion.v3x else GemmArguments2x
|
| 666 |
-
return ArgClass(operation, problem_size, A, B, C, D, gemm_mode, **kwargs)
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
class GemmGroupedArguments:
|
| 670 |
-
"""
|
| 671 |
-
Argument wrapper for GEMM Grouped. It encodes problem information and
|
| 672 |
-
user-provide tensors into the kernel's argument
|
| 673 |
-
|
| 674 |
-
:param operation: the GEMM Grouped operation to take the argument
|
| 675 |
-
:type operation: :class:`cutlass_cppgen.backend.GemmOperationGrouped`
|
| 676 |
-
|
| 677 |
-
:param problem_size: list of GEMM problem size gemm(M, N, K)
|
| 678 |
-
:type operation: list[:class:`cutlass_cppgen.shape.GemmCoord`]
|
| 679 |
-
|
| 680 |
-
:param A: list of tensor A
|
| 681 |
-
:type A: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 682 |
-
|
| 683 |
-
:param B: list of tensor B
|
| 684 |
-
:type B: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 685 |
-
|
| 686 |
-
:param C: list of tensor C
|
| 687 |
-
:type C: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 688 |
-
|
| 689 |
-
:param D: list of tensor D
|
| 690 |
-
:type D: list[cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray]
|
| 691 |
-
|
| 692 |
-
:param output_op: output operator, optional
|
| 693 |
-
:type output_op: :class:`cutlass_cppgen.backend.LinearCombinationFunctorArguments`
|
| 694 |
-
|
| 695 |
-
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 696 |
-
:type stream: :class:`cuda.cuda.CUstream`
|
| 697 |
-
"""
|
| 698 |
-
|
| 699 |
-
def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
|
| 700 |
-
# Get number of problems in the group
|
| 701 |
-
self.problem_count = len(problem_sizes)
|
| 702 |
-
|
| 703 |
-
# Check the input arguments
|
| 704 |
-
assert len(A) == self.problem_count
|
| 705 |
-
assert len(B) == self.problem_count
|
| 706 |
-
assert len(C) == self.problem_count
|
| 707 |
-
assert len(D) == self.problem_count
|
| 708 |
-
|
| 709 |
-
problem_size_host = []
|
| 710 |
-
self.ptr_A_host = []
|
| 711 |
-
self.ptr_B_host = []
|
| 712 |
-
self.ptr_C_host = []
|
| 713 |
-
self.ptr_D_host = []
|
| 714 |
-
|
| 715 |
-
lda_host = []
|
| 716 |
-
ldb_host = []
|
| 717 |
-
ldc_host = []
|
| 718 |
-
ldd_host = []
|
| 719 |
-
|
| 720 |
-
self.partitions = 1
|
| 721 |
-
|
| 722 |
-
self.operation = operation
|
| 723 |
-
|
| 724 |
-
# Get the threadblock
|
| 725 |
-
threadblock_shape = operation.tile_description.threadblock_shape
|
| 726 |
-
self.threadblock_shape = GemmCoord(
|
| 727 |
-
threadblock_shape[0],
|
| 728 |
-
threadblock_shape[1],
|
| 729 |
-
threadblock_shape[2],
|
| 730 |
-
)
|
| 731 |
-
self.threadblock_swizzle = operation.swizzling_functor
|
| 732 |
-
|
| 733 |
-
self.total_tiles = 0
|
| 734 |
-
|
| 735 |
-
self.gemm_arguments = []
|
| 736 |
-
|
| 737 |
-
self.stream = kwargs.get("stream", cuda.CUstream(0))
|
| 738 |
-
|
| 739 |
-
# Process the input arguments
|
| 740 |
-
for idx, problem_size in enumerate(problem_sizes):
|
| 741 |
-
M, N, K = problem_size.m, problem_size.n, problem_size.k
|
| 742 |
-
temp_argument = GemmArguments2x(
|
| 743 |
-
operation=operation,
|
| 744 |
-
problem_size=GemmCoord(M, N, K),
|
| 745 |
-
A=A[idx], B=B[idx], C=C[idx], D=D[idx])
|
| 746 |
-
self.gemm_arguments.append(temp_argument)
|
| 747 |
-
|
| 748 |
-
problem_size_host.append(
|
| 749 |
-
[temp_argument.problem_size.m,
|
| 750 |
-
temp_argument.problem_size.n,
|
| 751 |
-
temp_argument.problem_size.k]
|
| 752 |
-
)
|
| 753 |
-
|
| 754 |
-
self.ptr_A_host.append(int(temp_argument.ptr_A))
|
| 755 |
-
lda_host.append(temp_argument.lda)
|
| 756 |
-
|
| 757 |
-
self.ptr_B_host.append(int(temp_argument.ptr_B))
|
| 758 |
-
ldb_host.append(temp_argument.ldb)
|
| 759 |
-
|
| 760 |
-
self.ptr_C_host.append(int(temp_argument.ptr_C))
|
| 761 |
-
ldc_host.append(temp_argument.ldc)
|
| 762 |
-
|
| 763 |
-
self.ptr_D_host.append(int(temp_argument.ptr_D))
|
| 764 |
-
ldd_host.append(temp_argument.ldd)
|
| 765 |
-
|
| 766 |
-
# Get number of tiles
|
| 767 |
-
grid = self.operation.rt_module.get_grid_shape(
|
| 768 |
-
self.operation.rt_module.get_tiled_shape(
|
| 769 |
-
temp_argument.problem_size.ctype,
|
| 770 |
-
self.threadblock_shape.ctype,
|
| 771 |
-
temp_argument.batch_count
|
| 772 |
-
)
|
| 773 |
-
)
|
| 774 |
-
self.total_tiles += grid.x * grid.y * grid.z
|
| 775 |
-
|
| 776 |
-
self.problem_size_buffer = todevice(problem_size_host, np.int32)
|
| 777 |
-
self.ptr_A_buffer = todevice(self.ptr_A_host, np.int64)
|
| 778 |
-
self.ptr_B_buffer = todevice(self.ptr_B_host, np.int64)
|
| 779 |
-
self.ptr_C_buffer = todevice(self.ptr_C_host, np.int64)
|
| 780 |
-
self.ptr_D_buffer = todevice(self.ptr_D_host, np.int64)
|
| 781 |
-
|
| 782 |
-
self.lda_buffer = todevice(lda_host, np.int64)
|
| 783 |
-
self.ldb_buffer = todevice(ldb_host, np.int64)
|
| 784 |
-
self.ldc_buffer = todevice(ldc_host, np.int64)
|
| 785 |
-
self.ldd_buffer = todevice(ldd_host, np.int64)
|
| 786 |
-
|
| 787 |
-
if "output_op" in kwargs.keys():
|
| 788 |
-
self.alpha = kwargs["output_op"].alpha
|
| 789 |
-
self.beta = kwargs["output_op"].beta
|
| 790 |
-
else:
|
| 791 |
-
self.alpha = 1.0
|
| 792 |
-
self.beta = 0.0
|
| 793 |
-
|
| 794 |
-
if "output_op" in kwargs.keys():
|
| 795 |
-
self.output_op = kwargs["output_op"]
|
| 796 |
-
else:
|
| 797 |
-
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
| 798 |
-
|
| 799 |
-
# Get host problem size
|
| 800 |
-
self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
|
| 801 |
-
|
| 802 |
-
self.arguments = self.get_arguments()
|
| 803 |
-
|
| 804 |
-
self.initialize()
|
| 805 |
-
|
| 806 |
-
def get_arguments(self):
|
| 807 |
-
return self.operation.argument_type(
|
| 808 |
-
self.problem_size_buffer.ptr,
|
| 809 |
-
self.problem_count,
|
| 810 |
-
self.total_tiles,
|
| 811 |
-
self.output_op,
|
| 812 |
-
self.ptr_A_buffer.ptr,
|
| 813 |
-
self.ptr_B_buffer.ptr,
|
| 814 |
-
self.ptr_C_buffer.ptr,
|
| 815 |
-
self.ptr_D_buffer.ptr,
|
| 816 |
-
self.lda_buffer.ptr,
|
| 817 |
-
self.ldb_buffer.ptr,
|
| 818 |
-
self.ldc_buffer.ptr,
|
| 819 |
-
self.ldd_buffer.ptr,
|
| 820 |
-
ctypes.c_void_p(int(self.host_problem_size_ptr)),
|
| 821 |
-
)
|
| 822 |
-
|
| 823 |
-
def initialize(self):
|
| 824 |
-
# Get launch configuration
|
| 825 |
-
launch_config = self.operation.rt_module.plan(self)
|
| 826 |
-
|
| 827 |
-
# Get the host and evice workspace
|
| 828 |
-
device_workspace_size = self.operation.rt_module.get_device_workspace_size(self)
|
| 829 |
-
|
| 830 |
-
if device_workspace_size > 0:
|
| 831 |
-
self.workspace_buffer = device_mem_alloc(device_workspace_size)
|
| 832 |
-
workspace_ptr = self.workspace_buffer.ptr
|
| 833 |
-
err, = cuda.cuMemsetD32(
|
| 834 |
-
workspace_ptr, 0, device_workspace_size // 4)
|
| 835 |
-
else:
|
| 836 |
-
workspace_ptr = None
|
| 837 |
-
|
| 838 |
-
if self.operation.precompute_mode == SchedulerMode.Host:
|
| 839 |
-
device_workspace_ptr = self.operation.rt_module.host_precompute(
|
| 840 |
-
self, self.operation.rt_module.get_workspace_size(self),)
|
| 841 |
-
else:
|
| 842 |
-
device_workspace_ptr = 0
|
| 843 |
-
|
| 844 |
-
result = self.operation.rt_module.get_args(
|
| 845 |
-
ctypes.byref(self.arguments),
|
| 846 |
-
self.total_tiles,
|
| 847 |
-
ctypes.c_void_p(int(device_workspace_ptr)),
|
| 848 |
-
)
|
| 849 |
-
host_workspace = bytearray(result.contents)
|
| 850 |
-
|
| 851 |
-
device_workspace = None
|
| 852 |
-
|
| 853 |
-
self.host_workspace = host_workspace
|
| 854 |
-
self.device_workspace = device_workspace
|
| 855 |
-
self.launch_config = launch_config
|
| 856 |
-
|
| 857 |
-
def sync(self):
|
| 858 |
-
err, = cudart.cudaDeviceSynchronize()
|
| 859 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 860 |
-
raise RuntimeError("CUDA Error %s" % str(err))
|
| 861 |
-
for arg in self.gemm_arguments:
|
| 862 |
-
arg.sync(stream_sync=False)
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
################################################################################
|
| 866 |
-
# Base class for GEMM runtime module
|
| 867 |
-
################################################################################
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
class GemmRTbase(ExecutableOperation):
|
| 871 |
-
"""
|
| 872 |
-
GemmRT manages the CUTLASS runtime components
|
| 873 |
-
"""
|
| 874 |
-
|
| 875 |
-
KernelTemplate = r"""
|
| 876 |
-
extern "C"
|
| 877 |
-
__global__ void
|
| 878 |
-
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
| 879 |
-
|
| 880 |
-
// Dynamic shared memory base pointer
|
| 881 |
-
extern __shared__ int SharedStorageBase[];
|
| 882 |
-
|
| 883 |
-
// Declare pointer to dynamic shared memory.
|
| 884 |
-
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
| 885 |
-
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
| 886 |
-
|
| 887 |
-
${operation_name}${operation_suffix}::invoke(params, *shared_storage);
|
| 888 |
-
}
|
| 889 |
-
"""
|
| 890 |
-
|
| 891 |
-
def __init__(self, operation: "GemmOperation"):
|
| 892 |
-
super().__init__(operation)
|
| 893 |
-
|
| 894 |
-
self.operation = operation
|
| 895 |
-
threadblock_shape = operation.tile_description.threadblock_shape
|
| 896 |
-
self.threadblock_shape = GemmCoord(
|
| 897 |
-
threadblock_shape[0], threadblock_shape[1], threadblock_shape[2])
|
| 898 |
-
self.threadblock_swizzle = operation.swizzling_functor
|
| 899 |
-
|
| 900 |
-
# Threads per threadblock
|
| 901 |
-
self.threads = operation.tile_description.num_threads
|
| 902 |
-
|
| 903 |
-
def emit(self):
|
| 904 |
-
return self.emitter.emit(self.operation)
|
| 905 |
-
|
| 906 |
-
def can_implement(self, configuration, arguments):
|
| 907 |
-
raise NotImplementedError()
|
| 908 |
-
|
| 909 |
-
def get_host_workspace_size(self, arguments):
|
| 910 |
-
raise NotImplementedError()
|
| 911 |
-
|
| 912 |
-
def get_device_workspace_size(self, arguments):
|
| 913 |
-
return 0
|
| 914 |
-
|
| 915 |
-
def initialize(self):
|
| 916 |
-
err, = cuda.cuFuncSetAttribute(
|
| 917 |
-
self.kernel,
|
| 918 |
-
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
| 919 |
-
value=self.shared_memory_capacity)
|
| 920 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 921 |
-
raise RuntimeError(
|
| 922 |
-
f"CUDA error on call to cuFuncSetAttribute: {cuda.cuGetErrorString(err)[1]}"
|
| 923 |
-
)
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
################################################################################
|
| 927 |
-
# Runtime module for GEMM Universal
|
| 928 |
-
################################################################################
|
| 929 |
-
|
| 930 |
-
|
| 931 |
-
class GemmRTUniversal(GemmRTbase):
|
| 932 |
-
"""
|
| 933 |
-
GemmRTUniversal manages the CUTLASS runtime components
|
| 934 |
-
"""
|
| 935 |
-
|
| 936 |
-
HostTemplate = r"""
|
| 937 |
-
extern "C" {
|
| 938 |
-
// Get the size of params in bytes
|
| 939 |
-
int ${operation_name}_get_param_size(){
|
| 940 |
-
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 941 |
-
}
|
| 942 |
-
|
| 943 |
-
// Get the size of dynamic shared memory in bytes
|
| 944 |
-
int ${operation_name}_shared_memory_size() {
|
| 945 |
-
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 946 |
-
}
|
| 947 |
-
|
| 948 |
-
// Get the params as byte array
|
| 949 |
-
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int* workspace){
|
| 950 |
-
${operation_name}_base::Params* params;
|
| 951 |
-
params = new ${operation_name}_base::Params(*argument,
|
| 952 |
-
-1, // SM count. Only used for stream-K
|
| 953 |
-
-1 // Occupancy. Only used for stream-K
|
| 954 |
-
);
|
| 955 |
-
|
| 956 |
-
// Semaphore holds the pointer to the workspace in the Params struct
|
| 957 |
-
params->semaphore = workspace;
|
| 958 |
-
|
| 959 |
-
char *bytes = ((char*)(params));
|
| 960 |
-
char *output = new char[sizeof(${operation_name}_base::Params)];
|
| 961 |
-
for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
|
| 962 |
-
output[i] = bytes[i];
|
| 963 |
-
|
| 964 |
-
return output;
|
| 965 |
-
}
|
| 966 |
-
|
| 967 |
-
cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
|
| 968 |
-
cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
|
| 969 |
-
return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
|
| 970 |
-
problem_size, tile_size, split_k_slices);
|
| 971 |
-
}
|
| 972 |
-
|
| 973 |
-
dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
|
| 974 |
-
return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
|
| 975 |
-
}
|
| 976 |
-
}
|
| 977 |
-
"""
|
| 978 |
-
|
| 979 |
-
def __init__(self, operation):
|
| 980 |
-
super(GemmRTUniversal, self).__init__(operation)
|
| 981 |
-
self.extra_funcs = {
|
| 982 |
-
"get_tiled_shape": GemmCoord_,
|
| 983 |
-
"get_grid_shape": dim3_,
|
| 984 |
-
}
|
| 985 |
-
self.emitter = EmitGemmUniversalInstance(
|
| 986 |
-
"_type", operation.direct_store)
|
| 987 |
-
|
| 988 |
-
self.argument_type, self.epilogue_type = get_gemm_arguments(operation.epilogue_functor)
|
| 989 |
-
self.argtype = [
|
| 990 |
-
ctypes.POINTER(self.argument_type),
|
| 991 |
-
ctypes.POINTER(GemmCoord_), ctypes.c_int, ctypes.c_void_p
|
| 992 |
-
]
|
| 993 |
-
|
| 994 |
-
def plan(self, arguments):
|
| 995 |
-
grid = self.get_tiled_shape(
|
| 996 |
-
arguments.problem_size.ctype,
|
| 997 |
-
self.threadblock_shape.ctype,
|
| 998 |
-
arguments.batch_count
|
| 999 |
-
)
|
| 1000 |
-
|
| 1001 |
-
gemm_k_size = arguments.problem_size.k
|
| 1002 |
-
if arguments.gemm_mode in [GemmUniversalMode.Gemm, GemmUniversalMode.GemmSplitKParallel]:
|
| 1003 |
-
alignk = max(max(128 // DataTypeSize[self.operation.A.element],
|
| 1004 |
-
128 // DataTypeSize[self.operation.B.element]), 1)
|
| 1005 |
-
|
| 1006 |
-
gemm_k_size = (((arguments.problem_size.k + arguments.batch_count - 1) //
|
| 1007 |
-
arguments.batch_count + alignk - 1) // alignk) * alignk
|
| 1008 |
-
|
| 1009 |
-
if gemm_k_size:
|
| 1010 |
-
grid_z = (arguments.problem_size.k + gemm_k_size - 1) // gemm_k_size
|
| 1011 |
-
grid = GemmCoord(grid.m, grid.n, grid_z).ctype
|
| 1012 |
-
|
| 1013 |
-
arguments.grid_tiled_shape = dim3_(grid.m, grid.n, grid.k)
|
| 1014 |
-
grid = self.get_grid_shape(grid)
|
| 1015 |
-
arguments.gemm_k_size = gemm_k_size
|
| 1016 |
-
return LaunchConfiguration(
|
| 1017 |
-
[grid.x, grid.y, grid.z],
|
| 1018 |
-
[self.threads, 1, 1],
|
| 1019 |
-
self.shared_memory_capacity)
|
| 1020 |
-
|
| 1021 |
-
def get_device_workspace_size(self, arguments: GemmArguments):
|
| 1022 |
-
workspace_bytes = 0
|
| 1023 |
-
if arguments.gemm_mode == GemmUniversalMode.GemmSplitKParallel:
|
| 1024 |
-
workspace_bytes = (DataTypeSize[arguments.operation.C.element]
|
| 1025 |
-
* arguments.batched_stride_D * arguments.grid_tiled_shape.z // 8)
|
| 1026 |
-
elif (arguments.gemm_mode == GemmUniversalMode.Gemm and
|
| 1027 |
-
arguments.split_k_slices > 1):
|
| 1028 |
-
workspace_bytes = 4 * arguments.grid_tiled_shape.x * arguments.grid_tiled_shape.y
|
| 1029 |
-
|
| 1030 |
-
return workspace_bytes
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
class GemmRTUniversalStreamK(GemmRTUniversal):
|
| 1034 |
-
"""
|
| 1035 |
-
Manages the CUTLASS runtime components for 2.x stream K kernels
|
| 1036 |
-
"""
|
| 1037 |
-
|
| 1038 |
-
HostTemplate = r"""
|
| 1039 |
-
extern "C" {
|
| 1040 |
-
// Get the size of params in bytes
|
| 1041 |
-
int ${operation_name}_get_param_size(){
|
| 1042 |
-
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 1043 |
-
}
|
| 1044 |
-
|
| 1045 |
-
// Get the size of dynamic shared memory in bytes
|
| 1046 |
-
int ${operation_name}_shared_memory_size() {
|
| 1047 |
-
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 1048 |
-
}
|
| 1049 |
-
|
| 1050 |
-
using GemmType = ${operation_name}_base;
|
| 1051 |
-
|
| 1052 |
-
// Get the params as byte array
|
| 1053 |
-
char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace,
|
| 1054 |
-
int sm_count, int occupancy) {
|
| 1055 |
-
GemmType::Params* params;
|
| 1056 |
-
params = new GemmType::Params(*argument, sm_count, occupancy);
|
| 1057 |
-
|
| 1058 |
-
params->init_workspace(workspace);
|
| 1059 |
-
|
| 1060 |
-
char *bytes = ((char*)(params));
|
| 1061 |
-
char *output = new char[sizeof(GemmType::Params)];
|
| 1062 |
-
for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
|
| 1063 |
-
output[i] = bytes[i];
|
| 1064 |
-
|
| 1065 |
-
return output;
|
| 1066 |
-
}
|
| 1067 |
-
|
| 1068 |
-
dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
|
| 1069 |
-
typename GemmType::Params params(*args, device_sms, sm_occupancy);
|
| 1070 |
-
return params.get_grid_dims();
|
| 1071 |
-
}
|
| 1072 |
-
|
| 1073 |
-
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) {
|
| 1074 |
-
typename GemmType::Params params(*args, device_sms, sm_occupancy);
|
| 1075 |
-
return params.get_workspace_size();
|
| 1076 |
-
}
|
| 1077 |
-
}
|
| 1078 |
-
"""
|
| 1079 |
-
|
| 1080 |
-
def __init__(self, operation: "GemmOperation"):
|
| 1081 |
-
super(GemmRTUniversalStreamK, self).__init__(operation)
|
| 1082 |
-
self.extra_funcs = {
|
| 1083 |
-
"get_grid_shape": GemmCoord_,
|
| 1084 |
-
"get_kernel_workspace_size": ctypes.c_uint64,
|
| 1085 |
-
}
|
| 1086 |
-
self._occupancy = None
|
| 1087 |
-
self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor)
|
| 1088 |
-
|
| 1089 |
-
@property
|
| 1090 |
-
def occupancy(self):
|
| 1091 |
-
if self._occupancy is None:
|
| 1092 |
-
err, self._occupancy = cuda.cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
|
| 1093 |
-
self.kernel, self.threads, self.shared_memory_capacity,
|
| 1094 |
-
cuda.CUoccupancy_flags.CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE)
|
| 1095 |
-
|
| 1096 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 1097 |
-
raise RuntimeError(
|
| 1098 |
-
"CUDA error on call to cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags: "
|
| 1099 |
-
f"{cuda.cuGetErrorString(err)[1]}")
|
| 1100 |
-
return self._occupancy
|
| 1101 |
-
|
| 1102 |
-
def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int):
|
| 1103 |
-
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy)
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
################################################################################
|
| 1107 |
-
# Runtime module for GEMM Universal within CUTLASS 3
|
| 1108 |
-
################################################################################
|
| 1109 |
-
|
| 1110 |
-
|
| 1111 |
-
class GemmRTUniversal3x(GemmRTUniversal):
|
| 1112 |
-
"""
|
| 1113 |
-
Manages the CUTLASS runtime components for 3.x kernels
|
| 1114 |
-
"""
|
| 1115 |
-
|
| 1116 |
-
KernelTemplate = r"""
|
| 1117 |
-
|
| 1118 |
-
using Operator = ${operation_name}${operation_suffix};
|
| 1119 |
-
extern "C"
|
| 1120 |
-
__global__ __launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor)
|
| 1121 |
-
void ${operation_name}(__grid_constant__ typename Operator::Params const params) {
|
| 1122 |
-
// Dynamic shared memory base pointer
|
| 1123 |
-
extern __shared__ char smem[];
|
| 1124 |
-
|
| 1125 |
-
// Declare pointer to dynamic shared memory.
|
| 1126 |
-
Operator op;
|
| 1127 |
-
op(params, smem);
|
| 1128 |
-
}
|
| 1129 |
-
"""
|
| 1130 |
-
HostTemplate = r"""
|
| 1131 |
-
extern "C" {
|
| 1132 |
-
// Get the size of params in bytes
|
| 1133 |
-
int ${operation_name}_get_param_size(){
|
| 1134 |
-
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 1135 |
-
}
|
| 1136 |
-
|
| 1137 |
-
// Get the size of dynamic shared memory in bytes
|
| 1138 |
-
int ${operation_name}_shared_memory_size() {
|
| 1139 |
-
return ${operation_name}${operation_suffix}::SharedStorageSize;
|
| 1140 |
-
}
|
| 1141 |
-
|
| 1142 |
-
using GemmType = ${operation_name}_base;
|
| 1143 |
-
|
| 1144 |
-
bool ${operation_name}_uses_default_epilogue() {
|
| 1145 |
-
return std::is_same_v<GemmType::CollectiveEpilogue::DispatchPolicy, cutlass::gemm::EpilogueDefault>;
|
| 1146 |
-
}
|
| 1147 |
-
|
| 1148 |
-
// Get the workspace size
|
| 1149 |
-
uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) {
|
| 1150 |
-
return GemmType::get_workspace_size(*argument);
|
| 1151 |
-
}
|
| 1152 |
-
|
| 1153 |
-
// Get the params as byte array
|
| 1154 |
-
char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){
|
| 1155 |
-
GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace);
|
| 1156 |
-
char *bytes = ((char*)(¶ms));
|
| 1157 |
-
char *output = new char[sizeof(GemmType::Params)];
|
| 1158 |
-
for (unsigned int i = 0; i < sizeof(GemmType::Params); i ++)
|
| 1159 |
-
output[i] = bytes[i];
|
| 1160 |
-
|
| 1161 |
-
return output;
|
| 1162 |
-
}
|
| 1163 |
-
|
| 1164 |
-
// Get the total number of blocks for a persistent kernel
|
| 1165 |
-
uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) {
|
| 1166 |
-
auto problem_shape_MNKL = append<4>(problem, Int<1>{});
|
| 1167 |
-
auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] =
|
| 1168 |
-
cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl(
|
| 1169 |
-
problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{});
|
| 1170 |
-
return problem_blocks_m * problem_blocks_n * problem_blocks_l;
|
| 1171 |
-
}
|
| 1172 |
-
|
| 1173 |
-
// Get the grid shape
|
| 1174 |
-
dim3 ${operation_name}_get_grid_shape(GemmType::Arguments* args, int* workspace) {
|
| 1175 |
-
auto tmp_params = GemmType::to_underlying_arguments(*args, workspace);
|
| 1176 |
-
return GemmType::get_grid_shape(tmp_params);
|
| 1177 |
-
}
|
| 1178 |
-
|
| 1179 |
-
// Get the block shape
|
| 1180 |
-
dim3 ${operation_name}_get_block_shape() {
|
| 1181 |
-
return GemmType::get_block_shape();
|
| 1182 |
-
}
|
| 1183 |
-
}
|
| 1184 |
-
"""
|
| 1185 |
-
|
| 1186 |
-
def __init__(self, operation):
|
| 1187 |
-
super(GemmRTUniversal3x, self).__init__(operation)
|
| 1188 |
-
self.extra_funcs = {
|
| 1189 |
-
"get_grid_shape": dim3_,
|
| 1190 |
-
"get_block_shape": dim3_,
|
| 1191 |
-
"get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64,
|
| 1192 |
-
"get_kernel_workspace_size": ctypes.c_uint64,
|
| 1193 |
-
"uses_default_epilogue": ctypes.c_bool,
|
| 1194 |
-
}
|
| 1195 |
-
self.emitter = EmitGemmUniversalInstance3x("_type")
|
| 1196 |
-
|
| 1197 |
-
def get_device_workspace_size(self, arguments: GemmArguments3x):
|
| 1198 |
-
return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()))
|
| 1199 |
-
|
| 1200 |
-
|
| 1201 |
-
class EmitGemmUniversalInstance3x:
|
| 1202 |
-
"""Responsible for emitting a CUTLASS 3 template definition"""
|
| 1203 |
-
|
| 1204 |
-
def __init__(self, operation_suffix=""):
|
| 1205 |
-
self.operation_suffix = operation_suffix
|
| 1206 |
-
self.includes = [
|
| 1207 |
-
"cutlass/cutlass.h",
|
| 1208 |
-
"cute/tensor.hpp",
|
| 1209 |
-
"cute/atom/mma_atom.hpp",
|
| 1210 |
-
"cutlass/numeric_types.h",
|
| 1211 |
-
"cutlass/gemm/collective/collective_builder.hpp",
|
| 1212 |
-
"cutlass/gemm/kernel/sm90_tile_scheduler.hpp",
|
| 1213 |
-
"cutlass/gemm/kernel/gemm_universal.hpp",
|
| 1214 |
-
"cutlass/epilogue/collective/collective_builder.hpp",
|
| 1215 |
-
"cutlass/epilogue/collective/default_epilogue.hpp",
|
| 1216 |
-
"cutlass/epilogue/thread/linear_combination.h"
|
| 1217 |
-
]
|
| 1218 |
-
self.gemm_template_kernel = """
|
| 1219 |
-
using namespace cute;
|
| 1220 |
-
|
| 1221 |
-
using CollectiveEpilogue =
|
| 1222 |
-
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 1223 |
-
${arch}, ${opcode_class},
|
| 1224 |
-
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1225 |
-
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1226 |
-
cutlass::epilogue::collective::EpilogueTileAuto,
|
| 1227 |
-
${element_accumulator}, ${element_epilogue},
|
| 1228 |
-
${element_c}, ${layout_c}, ${align_c},
|
| 1229 |
-
${element_d}, ${layout_d}, ${align_d},
|
| 1230 |
-
${epilogue_schedule}
|
| 1231 |
-
>::CollectiveOp;
|
| 1232 |
-
|
| 1233 |
-
using CollectiveMainloop =
|
| 1234 |
-
typename cutlass::gemm::collective::CollectiveBuilder<
|
| 1235 |
-
${arch}, ${opcode_class},
|
| 1236 |
-
${element_a}, ${layout_a}, ${align_a},
|
| 1237 |
-
${element_b}, ${layout_b}, ${align_b},
|
| 1238 |
-
${element_accumulator},
|
| 1239 |
-
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1240 |
-
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1241 |
-
${stage_count_type},
|
| 1242 |
-
${kernel_schedule}
|
| 1243 |
-
>::CollectiveOp;
|
| 1244 |
-
|
| 1245 |
-
// Gemm operator ${operation_name}
|
| 1246 |
-
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
| 1247 |
-
Shape<int,int,int,int>,
|
| 1248 |
-
CollectiveMainloop,
|
| 1249 |
-
CollectiveEpilogue,
|
| 1250 |
-
${tile_scheduler}
|
| 1251 |
-
>;
|
| 1252 |
-
|
| 1253 |
-
// Define named type
|
| 1254 |
-
struct ${operation_name}${operation_suffix} :
|
| 1255 |
-
public ${operation_name}_base { };
|
| 1256 |
-
"""
|
| 1257 |
-
self.gemm_template_kernel_visitor = """
|
| 1258 |
-
using namespace cute;
|
| 1259 |
-
|
| 1260 |
-
${callback_decl}
|
| 1261 |
-
|
| 1262 |
-
using CollectiveEpilogue =
|
| 1263 |
-
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 1264 |
-
${arch}, ${opcode_class},
|
| 1265 |
-
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1266 |
-
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1267 |
-
cutlass::epilogue::collective::EpilogueTileAuto,
|
| 1268 |
-
${element_accumulator}, ${element_epilogue},
|
| 1269 |
-
ElementC, StrideC, ${align_c},
|
| 1270 |
-
ElementD, StrideD, ${align_d},
|
| 1271 |
-
${epilogue_schedule},
|
| 1272 |
-
${callback_name}
|
| 1273 |
-
>::CollectiveOp;
|
| 1274 |
-
|
| 1275 |
-
using CollectiveMainloop =
|
| 1276 |
-
typename cutlass::gemm::collective::CollectiveBuilder<
|
| 1277 |
-
${arch}, ${opcode_class},
|
| 1278 |
-
${element_a}, ${layout_a}, ${align_a},
|
| 1279 |
-
${element_b}, ${layout_b}, ${align_b},
|
| 1280 |
-
${element_accumulator},
|
| 1281 |
-
cute::Shape<cute::_${threadblock_shape_m}, cute::_${threadblock_shape_n}, cute::_${threadblock_shape_k}>,
|
| 1282 |
-
cute::Shape<cute::_${cluster_m},cute::_${cluster_n},cute::_${cluster_k}>,
|
| 1283 |
-
${stage_count_type},
|
| 1284 |
-
${kernel_schedule}
|
| 1285 |
-
>::CollectiveOp;
|
| 1286 |
-
|
| 1287 |
-
// Gemm operator ${operation_name}
|
| 1288 |
-
using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal<
|
| 1289 |
-
Shape<int,int,int,int>,
|
| 1290 |
-
CollectiveMainloop,
|
| 1291 |
-
CollectiveEpilogue,
|
| 1292 |
-
${tile_scheduler}
|
| 1293 |
-
>;
|
| 1294 |
-
|
| 1295 |
-
// Define named type
|
| 1296 |
-
struct ${operation_name}${operation_suffix} :
|
| 1297 |
-
public ${operation_name}_base { };
|
| 1298 |
-
"""
|
| 1299 |
-
|
| 1300 |
-
self.gemm_template_device = self.gemm_template_kernel + """
|
| 1301 |
-
|
| 1302 |
-
// Define device-level operator
|
| 1303 |
-
using DeviceKernel = cutlass::gemm::device::GemmUniversalAdapter<${operation_name}${operation_suffix}>;
|
| 1304 |
-
"""
|
| 1305 |
-
|
| 1306 |
-
def emit(self, operation):
|
| 1307 |
-
# Support built-in epilogue functors or user-defined functions
|
| 1308 |
-
|
| 1309 |
-
if operation.tile_description.stages is None or operation.tile_description.stages == 0:
|
| 1310 |
-
stage_count_type = "cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>"
|
| 1311 |
-
else:
|
| 1312 |
-
stage_count_type = "_" + str(operation.tile_description.stages)
|
| 1313 |
-
|
| 1314 |
-
if operation.emission_type == EmissionType.Kernel:
|
| 1315 |
-
gemm_template = self.gemm_template_kernel
|
| 1316 |
-
else:
|
| 1317 |
-
gemm_template = self.gemm_template_device
|
| 1318 |
-
|
| 1319 |
-
kschedule = KernelScheduleType.ScheduleAuto
|
| 1320 |
-
eschedule = EpilogueScheduleType.ScheduleAuto
|
| 1321 |
-
tschedule = TileSchedulerType.Default
|
| 1322 |
-
if operation.tile_description.kernel_schedule is not None:
|
| 1323 |
-
kschedule = operation.tile_description.kernel_schedule
|
| 1324 |
-
if operation.tile_description.epilogue_schedule is not None:
|
| 1325 |
-
eschedule = operation.tile_description.epilogue_schedule
|
| 1326 |
-
if operation.tile_description.tile_scheduler is not None:
|
| 1327 |
-
tschedule = operation.tile_description.tile_scheduler
|
| 1328 |
-
|
| 1329 |
-
emit_tile_m, emit_tile_n, emit_tile_k = operation.tile_description.blackwell_threadblock_shape
|
| 1330 |
-
|
| 1331 |
-
values = {
|
| 1332 |
-
"operation_name": operation.procedural_name(),
|
| 1333 |
-
"operation_suffix": self.operation_suffix,
|
| 1334 |
-
"element_a": DataTypeTag[operation.A.element],
|
| 1335 |
-
"layout_a": LayoutTag[operation.A.layout],
|
| 1336 |
-
"element_b": DataTypeTag[operation.B.element],
|
| 1337 |
-
"layout_b": LayoutTag[operation.B.layout],
|
| 1338 |
-
"element_c": DataTypeTag[operation.C.element],
|
| 1339 |
-
"layout_c": LayoutTag[operation.C.layout],
|
| 1340 |
-
"element_d": DataTypeTag[operation.epilogue_functor.element_output],
|
| 1341 |
-
"layout_d": LayoutTag[operation.C.layout],
|
| 1342 |
-
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 1343 |
-
"element_epilogue": DataTypeTag[operation.epilogue_functor.element_epilogue],
|
| 1344 |
-
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 1345 |
-
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
| 1346 |
-
"threadblock_shape_m": str(emit_tile_m),
|
| 1347 |
-
"threadblock_shape_n": str(emit_tile_n),
|
| 1348 |
-
"threadblock_shape_k": str(emit_tile_k),
|
| 1349 |
-
"cluster_m": str(operation.tile_description.cluster_shape[0]),
|
| 1350 |
-
"cluster_n": str(operation.tile_description.cluster_shape[1]),
|
| 1351 |
-
"cluster_k": str(operation.tile_description.cluster_shape[2]),
|
| 1352 |
-
"align_a": str(operation.A.alignment),
|
| 1353 |
-
"align_b": str(operation.B.alignment),
|
| 1354 |
-
"align_c": str(operation.C.alignment),
|
| 1355 |
-
"align_d": str(operation.C.alignment),
|
| 1356 |
-
"stage_count_type": stage_count_type,
|
| 1357 |
-
"kernel_schedule": KernelScheduleTag[kschedule],
|
| 1358 |
-
"epilogue_schedule": EpilogueScheduleTag[eschedule],
|
| 1359 |
-
"tile_scheduler": TileSchedulerTag[tschedule]
|
| 1360 |
-
}
|
| 1361 |
-
if hasattr(operation.epilogue_functor, "visitor"):
|
| 1362 |
-
callback_name, callback_decl = operation.epilogue_functor.emit(operation)
|
| 1363 |
-
values["callback_name"] = callback_name
|
| 1364 |
-
values["callback_decl"] = callback_decl
|
| 1365 |
-
return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
|
| 1366 |
-
|
| 1367 |
-
else:
|
| 1368 |
-
values["epilogue_functor"] = operation.epilogue_functor.emit()
|
| 1369 |
-
return SubstituteTemplate(gemm_template, values)
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
-
###################################################################################################
|
| 1373 |
-
# Runtime module for GEMM Grouped
|
| 1374 |
-
###################################################################################################
|
| 1375 |
-
|
| 1376 |
-
|
| 1377 |
-
class GemmRTGrouped(GemmRTbase):
|
| 1378 |
-
"""
|
| 1379 |
-
GemmRTGrouped manages the CUTLASS runtime components
|
| 1380 |
-
"""
|
| 1381 |
-
|
| 1382 |
-
KernelTemplate = r"""
|
| 1383 |
-
extern "C"
|
| 1384 |
-
__global__ void
|
| 1385 |
-
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
| 1386 |
-
|
| 1387 |
-
// Dynamic shared memory base pointer
|
| 1388 |
-
extern __shared__ int SharedStorageBase[];
|
| 1389 |
-
|
| 1390 |
-
// Declare pointer to dynamic shared memory.
|
| 1391 |
-
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
| 1392 |
-
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
| 1393 |
-
|
| 1394 |
-
${operation_name}${operation_suffix} op;
|
| 1395 |
-
|
| 1396 |
-
op(params, *shared_storage);
|
| 1397 |
-
}
|
| 1398 |
-
"""
|
| 1399 |
-
|
| 1400 |
-
HostTemplate = r"""
|
| 1401 |
-
extern "C" {
|
| 1402 |
-
|
| 1403 |
-
// precompute scheduling information
|
| 1404 |
-
char * ${operation_name}_precompute(${operation_name}_base::Arguments const &args, int tile_count, size_t workspace_bytes) {
|
| 1405 |
-
char* host_workspace = new char[workspace_bytes];
|
| 1406 |
-
${operation_name}_base::ProblemVisitor::host_precompute(
|
| 1407 |
-
args.host_problem_sizes,
|
| 1408 |
-
args.problem_count,
|
| 1409 |
-
args.threadblock_count,
|
| 1410 |
-
(void*)host_workspace
|
| 1411 |
-
);
|
| 1412 |
-
return host_workspace;
|
| 1413 |
-
}
|
| 1414 |
-
|
| 1415 |
-
// Get the size of params in bytes
|
| 1416 |
-
int ${operation_name}_get_param_size(){
|
| 1417 |
-
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 1418 |
-
}
|
| 1419 |
-
|
| 1420 |
-
// Get the size of dynamic shared memory in bytes
|
| 1421 |
-
int ${operation_name}_shared_memory_size() {
|
| 1422 |
-
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 1423 |
-
}
|
| 1424 |
-
|
| 1425 |
-
// Get the params as byte array
|
| 1426 |
-
char* ${operation_name}_get_params(${operation_name}_base::Arguments* argument, int tile_count, void* workspace=nullptr){
|
| 1427 |
-
${operation_name}_base::Params* params;
|
| 1428 |
-
params = new ${operation_name}_base::Params(*argument, workspace, tile_count);
|
| 1429 |
-
|
| 1430 |
-
char *bytes = ((char*)(params));
|
| 1431 |
-
char *output = new char[sizeof(${operation_name}_base::Params)];
|
| 1432 |
-
for (unsigned int i = 0; i < sizeof(${operation_name}_base::Params); i ++)
|
| 1433 |
-
output[i] = bytes[i];
|
| 1434 |
-
|
| 1435 |
-
return output;
|
| 1436 |
-
}
|
| 1437 |
-
|
| 1438 |
-
cutlass::gemm::GemmCoord ${operation_name}_get_tiled_shape(
|
| 1439 |
-
cutlass::gemm::GemmCoord problem_size, cutlass::gemm::GemmCoord tile_size, int split_k_slices) {
|
| 1440 |
-
return ${operation_name}_base::ThreadblockSwizzle::get_tiled_shape(
|
| 1441 |
-
problem_size, tile_size, split_k_slices);
|
| 1442 |
-
}
|
| 1443 |
-
|
| 1444 |
-
dim3 ${operation_name}_get_grid_shape(cutlass::gemm::GemmCoord tiled_shape) {
|
| 1445 |
-
return ${operation_name}_base::ThreadblockSwizzle::get_grid_shape(tiled_shape);
|
| 1446 |
-
}
|
| 1447 |
-
}
|
| 1448 |
-
"""
|
| 1449 |
-
|
| 1450 |
-
def __init__(self, operation: "GemmOperation"):
|
| 1451 |
-
super(GemmRTGrouped, self).__init__(operation)
|
| 1452 |
-
self.extra_funcs = {
|
| 1453 |
-
"precompute": None,
|
| 1454 |
-
"get_tiled_shape": GemmCoord_,
|
| 1455 |
-
"get_grid_shape": dim3_,
|
| 1456 |
-
}
|
| 1457 |
-
self.emitter = EmitGemmGroupedInstance("_type")
|
| 1458 |
-
self.argument_type, self.epilogue_type = get_gemm_grouped_arguments(operation.epilogue_functor)
|
| 1459 |
-
self.argtype = [ctypes.POINTER(self.argument_type), ctypes.c_int, ctypes.c_void_p]
|
| 1460 |
-
|
| 1461 |
-
def host_precompute(self, arguments, workspace_bytes):
|
| 1462 |
-
self.precompute.argtype = [
|
| 1463 |
-
self.argtype[0], ctypes.c_int, ctypes.c_longlong]
|
| 1464 |
-
self.precompute.restype = ctypes.POINTER(ctypes.c_byte * workspace_bytes)
|
| 1465 |
-
|
| 1466 |
-
problem_info = self.precompute(
|
| 1467 |
-
ctypes.byref(arguments.arguments),
|
| 1468 |
-
arguments.total_tiles,
|
| 1469 |
-
workspace_bytes)
|
| 1470 |
-
problem_info_array = bytearray(problem_info.contents)
|
| 1471 |
-
|
| 1472 |
-
# copy to device memory
|
| 1473 |
-
return todevice(problem_info_array).ptr
|
| 1474 |
-
|
| 1475 |
-
def plan(self, arguments):
|
| 1476 |
-
return LaunchConfiguration(
|
| 1477 |
-
[arguments.total_tiles, 1, 1],
|
| 1478 |
-
[self.threads, 1, 1],
|
| 1479 |
-
self.shared_memory_capacity,
|
| 1480 |
-
)
|
| 1481 |
-
|
| 1482 |
-
def get_workspace_size(self, arguments):
|
| 1483 |
-
if self.operation.precompute_mode == SchedulerMode.Device:
|
| 1484 |
-
return 0
|
| 1485 |
-
elif self.operation.precompute_mode == SchedulerMode.Host:
|
| 1486 |
-
total_tiles = arguments.total_tiles
|
| 1487 |
-
entries_per_block = 1
|
| 1488 |
-
return 8 * entries_per_block * total_tiles # three int32_t
|
| 1489 |
-
|
| 1490 |
-
|
| 1491 |
-
################################################################################
|
| 1492 |
-
# Runtime module for GEMM and grouped GEMM
|
| 1493 |
-
################################################################################
|
| 1494 |
-
|
| 1495 |
-
|
| 1496 |
-
class GemmOperationBase:
|
| 1497 |
-
"""
|
| 1498 |
-
CUTLASS GEMM operation
|
| 1499 |
-
"""
|
| 1500 |
-
|
| 1501 |
-
def __init__(
|
| 1502 |
-
self, gemm_kind, arch, tile_description: TileDescription,
|
| 1503 |
-
A: TensorDescription, B: TensorDescription, C: TensorDescription,
|
| 1504 |
-
epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1,
|
| 1505 |
-
api=ApiVersion.v2x, emission_type=EmissionType.Kernel, **kwargs):
|
| 1506 |
-
self.operation_kind: OperationKind = OperationKind.Gemm
|
| 1507 |
-
self.arch: int = arch
|
| 1508 |
-
self.tile_description: TileDescription = tile_description
|
| 1509 |
-
self.gemm_kind: GemmKind = gemm_kind
|
| 1510 |
-
|
| 1511 |
-
self.api = api
|
| 1512 |
-
self.prefix = "3x" if self.api == ApiVersion.v3x else ""
|
| 1513 |
-
self.emission_type = emission_type
|
| 1514 |
-
|
| 1515 |
-
# Optionally swap the TensorDescriptions for operands A and B and transpose their
|
| 1516 |
-
# layouts. This is needed to mimic the transpose performed by device::GemmUniversal.
|
| 1517 |
-
# The code below uses deep copy to avoid overwritting the original TensorDescription
|
| 1518 |
-
self.switched = (self.api != ApiVersion.v3x and
|
| 1519 |
-
self.emission_type == EmissionType.Kernel and
|
| 1520 |
-
C.layout == LayoutType.ColumnMajor)
|
| 1521 |
-
|
| 1522 |
-
self.A, self.B, self.C = GemmOperationBase.get_operands(A, B, C, self.switched)
|
| 1523 |
-
|
| 1524 |
-
self.epilogue_functor = epilogue_functor
|
| 1525 |
-
self.swizzling_functor = swizzling_functor
|
| 1526 |
-
|
| 1527 |
-
if "direct_store" in kwargs:
|
| 1528 |
-
self.direct_store = kwargs["direct_store"]
|
| 1529 |
-
else:
|
| 1530 |
-
self.direct_store = False
|
| 1531 |
-
|
| 1532 |
-
@staticmethod
|
| 1533 |
-
def get_operands(A: TensorDescription, B: TensorDescription, C: TensorDescription, swap: bool):
|
| 1534 |
-
"""
|
| 1535 |
-
Makes copies of A, B, and C, and possibly transposes their order. If ``swap`` is set,
|
| 1536 |
-
A and B are swapped, and the layout of A, B, and C are transposed.
|
| 1537 |
-
|
| 1538 |
-
:param A: description of operand A
|
| 1539 |
-
:type A: TensorDescription
|
| 1540 |
-
:param B: description of operand B
|
| 1541 |
-
:type B: TensorDescription
|
| 1542 |
-
:param C: description of operand C
|
| 1543 |
-
:type C: TensorDescription
|
| 1544 |
-
|
| 1545 |
-
:return: descriptions of operands A, B, and C
|
| 1546 |
-
:rtype: tuple[TileDescription]
|
| 1547 |
-
"""
|
| 1548 |
-
if swap:
|
| 1549 |
-
A_out = copy.deepcopy(B)
|
| 1550 |
-
B_out = copy.deepcopy(A)
|
| 1551 |
-
C_out = copy.deepcopy(C)
|
| 1552 |
-
A_out.layout = transpose_layout(A_out.layout)
|
| 1553 |
-
B_out.layout = transpose_layout(B_out.layout)
|
| 1554 |
-
C_out.layout = transpose_layout(C_out.layout)
|
| 1555 |
-
else:
|
| 1556 |
-
A_out = copy.deepcopy(A)
|
| 1557 |
-
B_out = copy.deepcopy(B)
|
| 1558 |
-
C_out = copy.deepcopy(C)
|
| 1559 |
-
return A_out, B_out, C_out
|
| 1560 |
-
|
| 1561 |
-
def run(self, arguments: GemmArguments) -> cuda.CUresult:
|
| 1562 |
-
"""
|
| 1563 |
-
Configure and launch the cuda kernel with input arguments
|
| 1564 |
-
"""
|
| 1565 |
-
if self.emission_type == EmissionType.Device:
|
| 1566 |
-
raise Exception('Running a kernel via PyCUTLASS is only enabled with emission type "Kernel"')
|
| 1567 |
-
|
| 1568 |
-
err = self.rt_module.run(
|
| 1569 |
-
arguments.host_workspace,
|
| 1570 |
-
arguments.device_workspace,
|
| 1571 |
-
arguments.launch_config,
|
| 1572 |
-
arguments.stream
|
| 1573 |
-
)
|
| 1574 |
-
|
| 1575 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 1576 |
-
raise RuntimeError("CUDA Error %s" % str(err))
|
| 1577 |
-
|
| 1578 |
-
return err
|
| 1579 |
-
|
| 1580 |
-
def is_complex(self):
|
| 1581 |
-
complex_operators = [
|
| 1582 |
-
MathOperation.multiply_add_complex,
|
| 1583 |
-
MathOperation.multiply_add_complex_gaussian,
|
| 1584 |
-
MathOperation.multiply_add_complex_fast_f32,
|
| 1585 |
-
]
|
| 1586 |
-
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 1587 |
-
|
| 1588 |
-
def is_planar_complex(self):
|
| 1589 |
-
return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray)
|
| 1590 |
-
|
| 1591 |
-
def accumulator_type(self):
|
| 1592 |
-
accum = self.tile_description.math_instruction.element_accumulator
|
| 1593 |
-
|
| 1594 |
-
if self.is_complex():
|
| 1595 |
-
return get_complex_from_real(accum)
|
| 1596 |
-
|
| 1597 |
-
return accum
|
| 1598 |
-
|
| 1599 |
-
def short_math_name(self):
|
| 1600 |
-
if self.tile_description.math_instruction.math_operation == MathOperation.multiply_add_complex_gaussian:
|
| 1601 |
-
return "g%s" % ShortDataTypeNames[self.accumulator_type()]
|
| 1602 |
-
return ShortDataTypeNames[self.accumulator_type()]
|
| 1603 |
-
|
| 1604 |
-
def core_name(self):
|
| 1605 |
-
"""The basic operation kind is prefixed with a letter indicating the accumulation type."""
|
| 1606 |
-
|
| 1607 |
-
inst_shape = ""
|
| 1608 |
-
inst_operation = ""
|
| 1609 |
-
intermediate_type = ""
|
| 1610 |
-
|
| 1611 |
-
math_operations_map = {
|
| 1612 |
-
MathOperation.xor_popc: "xor",
|
| 1613 |
-
}
|
| 1614 |
-
|
| 1615 |
-
if (self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or
|
| 1616 |
-
self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp):
|
| 1617 |
-
math_op = self.tile_description.math_instruction.math_operation
|
| 1618 |
-
math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else ""
|
| 1619 |
-
|
| 1620 |
-
if self.tile_description.math_instruction.instruction_shape is not None:
|
| 1621 |
-
if self.api == ApiVersion.v3x and self.arch >= 90:
|
| 1622 |
-
inst_shape = "%dx%dx%d" % tuple(
|
| 1623 |
-
self.tile_description.math_instruction.instruction_shape)
|
| 1624 |
-
else:
|
| 1625 |
-
inst_shape = "%d%d%d" % tuple(
|
| 1626 |
-
self.tile_description.math_instruction.instruction_shape)
|
| 1627 |
-
else:
|
| 1628 |
-
inst_shape = "Default"
|
| 1629 |
-
inst_shape += math_op_string
|
| 1630 |
-
|
| 1631 |
-
if (self.tile_description.math_instruction.element_a != self.A.element and
|
| 1632 |
-
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator):
|
| 1633 |
-
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 1634 |
-
|
| 1635 |
-
return "%s%s%s%s" % (self.short_math_name(), inst_shape, intermediate_type, GemmKindNames[self.gemm_kind])
|
| 1636 |
-
|
| 1637 |
-
def extended_name(self):
|
| 1638 |
-
"""Append data types if they differ from compute type."""
|
| 1639 |
-
if self.is_complex():
|
| 1640 |
-
extended_name = "${core_name}"
|
| 1641 |
-
else:
|
| 1642 |
-
if (self.C.element != self.tile_description.math_instruction.element_accumulator and
|
| 1643 |
-
self.A.element != self.tile_description.math_instruction.element_accumulator):
|
| 1644 |
-
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 1645 |
-
elif (self.C.element == self.tile_description.math_instruction.element_accumulator and
|
| 1646 |
-
self.A.element != self.tile_description.math_instruction.element_accumulator):
|
| 1647 |
-
extended_name = "${core_name}_${element_a}"
|
| 1648 |
-
else:
|
| 1649 |
-
extended_name = "${core_name}"
|
| 1650 |
-
|
| 1651 |
-
extended_name = SubstituteTemplate(extended_name, {
|
| 1652 |
-
"element_a": DataTypeNames[self.A.element],
|
| 1653 |
-
"element_c": DataTypeNames[self.C.element],
|
| 1654 |
-
"core_name": self.core_name(),
|
| 1655 |
-
})
|
| 1656 |
-
|
| 1657 |
-
return extended_name
|
| 1658 |
-
|
| 1659 |
-
def extended_name_3x(self):
|
| 1660 |
-
"""Generates a string representing the MMA atom. Assumes accumulator type is C type."""
|
| 1661 |
-
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
| 1662 |
-
element_a=DataTypeNames[self.A.element],
|
| 1663 |
-
element_b=DataTypeNames[self.B.element],
|
| 1664 |
-
element_acc=DataTypeNames[self.accumulator_type()],
|
| 1665 |
-
element_c=DataTypeNames[self.C.element],
|
| 1666 |
-
element_d=DataTypeNames[self.epilogue_functor.element_output],
|
| 1667 |
-
core_name=self.core_name())
|
| 1668 |
-
return extended_name
|
| 1669 |
-
|
| 1670 |
-
def layout_name(self):
|
| 1671 |
-
if self.is_complex() or self.is_planar_complex():
|
| 1672 |
-
return "%s%s" % (
|
| 1673 |
-
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 1674 |
-
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)]
|
| 1675 |
-
)
|
| 1676 |
-
return "%s%s" % (ShortLayoutTypeNames[self.A.layout], ShortLayoutTypeNames[self.B.layout])
|
| 1677 |
-
|
| 1678 |
-
# Generates a short string representing the ABC layout tags (e.g. ntn or tnn)
|
| 1679 |
-
def layout_name_3x(self):
|
| 1680 |
-
if self.is_complex() or self.is_planar_complex():
|
| 1681 |
-
return "{}{}{}".format(
|
| 1682 |
-
ShortComplexLayoutNames[(self.A.layout, self.A.complex_transform)],
|
| 1683 |
-
ShortComplexLayoutNames[(self.B.layout, self.B.complex_transform)],
|
| 1684 |
-
ShortComplexLayoutNames[(self.C.layout, self.C.complex_transform)])
|
| 1685 |
-
else:
|
| 1686 |
-
return "{}{}{}".format(
|
| 1687 |
-
ShortLayoutTypeNames[self.A.layout],
|
| 1688 |
-
ShortLayoutTypeNames[self.B.layout],
|
| 1689 |
-
ShortLayoutTypeNames[self.C.layout])
|
| 1690 |
-
|
| 1691 |
-
# Generates a short string representing underlying kernel schedule type
|
| 1692 |
-
def kernel_schedule_name_3x(self):
|
| 1693 |
-
if self.tile_description.kernel_schedule is None:
|
| 1694 |
-
return KernelScheduleSuffixes[KernelScheduleType.ScheduleAuto]
|
| 1695 |
-
else:
|
| 1696 |
-
return KernelScheduleSuffixes[self.tile_description.kernel_schedule]
|
| 1697 |
-
|
| 1698 |
-
# Generates a short string representing underlying epilogue schedule type
|
| 1699 |
-
def epilogue_schedule_name_3x(self):
|
| 1700 |
-
if self.tile_description.epilogue_schedule is None:
|
| 1701 |
-
return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto]
|
| 1702 |
-
else:
|
| 1703 |
-
return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule]
|
| 1704 |
-
|
| 1705 |
-
def procedural_name(self):
|
| 1706 |
-
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
| 1707 |
-
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 1708 |
-
if self.api == ApiVersion.v3x and self.arch >= 90:
|
| 1709 |
-
kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}"
|
| 1710 |
-
return kernel_name_template.format(
|
| 1711 |
-
p=self.prefix,
|
| 1712 |
-
ar=self.arch,
|
| 1713 |
-
op=opcode_class_name,
|
| 1714 |
-
ex=self.extended_name_3x(),
|
| 1715 |
-
tbm=self.tile_description.threadblock_shape[0],
|
| 1716 |
-
tbn=self.tile_description.threadblock_shape[1],
|
| 1717 |
-
tbk=self.tile_description.threadblock_shape[2],
|
| 1718 |
-
cm=self.tile_description.cluster_shape[0],
|
| 1719 |
-
cn=self.tile_description.cluster_shape[1],
|
| 1720 |
-
ck=self.tile_description.cluster_shape[2],
|
| 1721 |
-
l=self.tile_description.stages,
|
| 1722 |
-
s=self.layout_name_3x(),
|
| 1723 |
-
al=str(self.A.alignment),
|
| 1724 |
-
k=self.kernel_schedule_name_3x(),
|
| 1725 |
-
e=self.epilogue_schedule_name_3x()
|
| 1726 |
-
)
|
| 1727 |
-
else:
|
| 1728 |
-
threadblock = self.tile_description.procedural_name_2x()
|
| 1729 |
-
return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format(
|
| 1730 |
-
p=self.prefix,
|
| 1731 |
-
op=opcode_class_name,
|
| 1732 |
-
ex=self.extended_name(),
|
| 1733 |
-
tb=threadblock,
|
| 1734 |
-
l=self.layout_name(),
|
| 1735 |
-
a=str(self.A.alignment)
|
| 1736 |
-
)
|
| 1737 |
-
|
| 1738 |
-
def configuration_name(self):
|
| 1739 |
-
"""The full procedural name indicates architecture, extended name, tile size, and layout."""
|
| 1740 |
-
return self.procedural_name()
|
| 1741 |
-
|
| 1742 |
-
|
| 1743 |
-
class GemmOperationUniversal(GemmOperationBase):
|
| 1744 |
-
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
|
| 1745 |
-
epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
|
| 1746 |
-
api = api_version(arch, tile_description.math_instruction.opcode_class, A.element)
|
| 1747 |
-
super(GemmOperationUniversal, self).__init__(GemmKind.Universal, arch, tile_description,
|
| 1748 |
-
A, B, C, epilogue_functor, swizzling_functor,
|
| 1749 |
-
api=api, **kwargs, )
|
| 1750 |
-
if api == ApiVersion.v3x:
|
| 1751 |
-
if swizzling_functor == SwizzlingFunctor.StreamK:
|
| 1752 |
-
raise Exception("Stream K swizzle functor is currently only supported for CUTLASS 2.x kernels")
|
| 1753 |
-
self.rt_module = GemmRTUniversal3x(self)
|
| 1754 |
-
else:
|
| 1755 |
-
if swizzling_functor == SwizzlingFunctor.StreamK:
|
| 1756 |
-
self.rt_module = GemmRTUniversalStreamK(self)
|
| 1757 |
-
else:
|
| 1758 |
-
self.rt_module = GemmRTUniversal(self)
|
| 1759 |
-
self.argument_type = self.rt_module.argument_type
|
| 1760 |
-
self.epilogue_type = self.rt_module.epilogue_type
|
| 1761 |
-
|
| 1762 |
-
def device_op(self):
|
| 1763 |
-
"""
|
| 1764 |
-
Returns a new GemmOperationUniversal object that is constructed with emission type
|
| 1765 |
-
``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
|
| 1766 |
-
any swappng performed by the kernel-emitted operation is reversed.
|
| 1767 |
-
|
| 1768 |
-
:return: operation ready for device-level code emission
|
| 1769 |
-
:rtype: GemmUniversalOperation
|
| 1770 |
-
"""
|
| 1771 |
-
A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
|
| 1772 |
-
return GemmOperationUniversal(self.arch, self.tile_description, A, B, C,
|
| 1773 |
-
self.epilogue_functor, self.swizzling_functor,
|
| 1774 |
-
emission_type=EmissionType.Device, direct_store=self.direct_store)
|
| 1775 |
-
|
| 1776 |
-
|
| 1777 |
-
class GemmOperationGrouped(GemmOperationBase):
|
| 1778 |
-
def __init__(self, arch, tile_description: TileDescription, A: TensorDescription, B, C,
|
| 1779 |
-
epilogue_functor, swizzling_functor=SwizzlingFunctor.Identity1, **kwargs):
|
| 1780 |
-
super(GemmOperationGrouped, self).__init__(GemmKind.Grouped, arch, tile_description,
|
| 1781 |
-
A, B, C, epilogue_functor, swizzling_functor, **kwargs)
|
| 1782 |
-
assert "precompute_mode" in kwargs.keys(), "missing keyword arguement 'precompute_mode'."
|
| 1783 |
-
self.precompute_mode = kwargs["precompute_mode"]
|
| 1784 |
-
self.rt_module = GemmRTGrouped(self)
|
| 1785 |
-
self.argument_type = self.rt_module.argument_type
|
| 1786 |
-
self.epilogue_type = self.rt_module.epilogue_type
|
| 1787 |
-
|
| 1788 |
-
def device_op(self):
|
| 1789 |
-
"""
|
| 1790 |
-
Returns a new GemmOperationGrouped object that is constructed with emission type
|
| 1791 |
-
``EmissionType.Device``. Since the device-emitted kernel does not require swapping,
|
| 1792 |
-
any swappng performed by the kernel-emitted operation is reversed.
|
| 1793 |
-
|
| 1794 |
-
:return: operation ready for device-level code emission
|
| 1795 |
-
:rtype: GemmOperationGrouped
|
| 1796 |
-
"""
|
| 1797 |
-
A, B, C = GemmOperationBase.get_operands(self.A, self.B, self.C, self.switched)
|
| 1798 |
-
return GemmOperationGrouped(
|
| 1799 |
-
self.arch, self.tile_description, A, B, C, self.epilogue_functor,
|
| 1800 |
-
self.swizzling_functor, emission_type=EmissionType.Device,
|
| 1801 |
-
direct_store=self.direct_store, precompute_mode=self.precompute_mode, )
|
| 1802 |
-
|
| 1803 |
-
|
| 1804 |
-
###################################################################################################
|
| 1805 |
-
#
|
| 1806 |
-
# Emits single instances of a CUTLASS device-wide operator
|
| 1807 |
-
#
|
| 1808 |
-
###################################################################################################
|
| 1809 |
-
|
| 1810 |
-
|
| 1811 |
-
class EmitGemmUniversalInstance:
|
| 1812 |
-
"""Responsible for emitting a CUTLASS template definition"""
|
| 1813 |
-
|
| 1814 |
-
def __init__(
|
| 1815 |
-
self,
|
| 1816 |
-
operation_suffix="",
|
| 1817 |
-
direct_store=False
|
| 1818 |
-
):
|
| 1819 |
-
self.operation_suffix = operation_suffix
|
| 1820 |
-
self.direct_store = direct_store
|
| 1821 |
-
self.includes = [
|
| 1822 |
-
"cutlass/cutlass.h",
|
| 1823 |
-
"cutlass/gemm_coord.h",
|
| 1824 |
-
"cutlass/numeric_types.h",
|
| 1825 |
-
"cutlass/arch/arch.h",
|
| 1826 |
-
"cutlass/arch/mma.h",
|
| 1827 |
-
"cutlass/layout/matrix.h",
|
| 1828 |
-
"cutlass/gemm/device/gemm.h",
|
| 1829 |
-
"cutlass/gemm/device/gemm_universal_adapter.h",
|
| 1830 |
-
"cutlass/gemm/kernel/default_gemm_universal.h",
|
| 1831 |
-
]
|
| 1832 |
-
if self.direct_store:
|
| 1833 |
-
self.includes.append(
|
| 1834 |
-
"cutlass/epilogue/threadblock/default_epilogue_direct_store.h"
|
| 1835 |
-
)
|
| 1836 |
-
self.gemm_template_kernel = """
|
| 1837 |
-
// Gemm operator ${operation_name}
|
| 1838 |
-
using ${operation_name}_base =
|
| 1839 |
-
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
| 1840 |
-
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1841 |
-
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1842 |
-
${element_c}, ${layout_c},
|
| 1843 |
-
${element_accumulator},
|
| 1844 |
-
${opcode_class},
|
| 1845 |
-
${arch},
|
| 1846 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1847 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1848 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1849 |
-
${epilogue_functor},
|
| 1850 |
-
${swizzling_functor},
|
| 1851 |
-
${stages},
|
| 1852 |
-
${math_operation}
|
| 1853 |
-
>::GemmKernel;
|
| 1854 |
-
|
| 1855 |
-
// Define named type
|
| 1856 |
-
struct ${operation_name}${operation_suffix} :
|
| 1857 |
-
public ${operation_name}_base { };
|
| 1858 |
-
"""
|
| 1859 |
-
|
| 1860 |
-
self.gemm_template_device = """
|
| 1861 |
-
// Gemm operator ${operation_name}
|
| 1862 |
-
using DeviceKernel =
|
| 1863 |
-
typename cutlass::gemm::device::GemmUniversal<
|
| 1864 |
-
// Data type and layout of operand A
|
| 1865 |
-
${element_a}, ${layout_a},
|
| 1866 |
-
// Data type and layout of operand B
|
| 1867 |
-
${element_b}, ${layout_b},
|
| 1868 |
-
// Data type and layout of operand C
|
| 1869 |
-
${element_c}, ${layout_c},
|
| 1870 |
-
// Data type of accumulator
|
| 1871 |
-
${element_accumulator},
|
| 1872 |
-
// Class of operation
|
| 1873 |
-
${opcode_class},
|
| 1874 |
-
// Compute capability of the target kernel
|
| 1875 |
-
${arch},
|
| 1876 |
-
// Threadblock tile shape
|
| 1877 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1878 |
-
// Warp tile shape
|
| 1879 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1880 |
-
// Instruction shape
|
| 1881 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1882 |
-
// Epilogue functor
|
| 1883 |
-
${epilogue_functor},
|
| 1884 |
-
// Swizzling function
|
| 1885 |
-
${swizzling_functor},
|
| 1886 |
-
// Number of pipeline stages
|
| 1887 |
-
${stages},
|
| 1888 |
-
// Alignment of operands A and B
|
| 1889 |
-
${align_a}, ${align_b},
|
| 1890 |
-
// Type of math operation
|
| 1891 |
-
${math_operation},
|
| 1892 |
-
// Complex transform types of operands A and B
|
| 1893 |
-
${transform_a}, ${transform_b}
|
| 1894 |
-
>;
|
| 1895 |
-
"""
|
| 1896 |
-
self.gemm_template_direct_store = """
|
| 1897 |
-
// Gemm operator ${operation_name}
|
| 1898 |
-
using ${operation_name}_default =
|
| 1899 |
-
typename cutlass::gemm::kernel::DefaultGemmUniversal<
|
| 1900 |
-
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1901 |
-
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1902 |
-
${element_c}, ${layout_c},
|
| 1903 |
-
${element_accumulator},
|
| 1904 |
-
${opcode_class},
|
| 1905 |
-
${arch},
|
| 1906 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1907 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1908 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1909 |
-
${epilogue_functor},
|
| 1910 |
-
${swizzling_functor},
|
| 1911 |
-
${stages},
|
| 1912 |
-
${math_operation}
|
| 1913 |
-
>::GemmKernel;
|
| 1914 |
-
|
| 1915 |
-
using ${operation_name}_base =
|
| 1916 |
-
cutlass::gemm::kernel::GemmUniversal<
|
| 1917 |
-
${operation_name}_default::Mma,
|
| 1918 |
-
cutlass::epilogue::threadblock::DefaultEpilogueDirectStore<
|
| 1919 |
-
${operation_name}_default::Epilogue
|
| 1920 |
-
>::Epilogue,
|
| 1921 |
-
${operation_name}_default::ThreadblockSwizzle
|
| 1922 |
-
>;
|
| 1923 |
-
|
| 1924 |
-
// Define named type
|
| 1925 |
-
struct ${operation_name}${operation_suffix} :
|
| 1926 |
-
public ${operation_name}_base { };
|
| 1927 |
-
"""
|
| 1928 |
-
self.gemm_template_kernel_visitor = """
|
| 1929 |
-
|
| 1930 |
-
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
| 1931 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1932 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1933 |
-
${element_c},
|
| 1934 |
-
${align_c},
|
| 1935 |
-
${epilogue_stages} /* epilogue stages */
|
| 1936 |
-
>;
|
| 1937 |
-
|
| 1938 |
-
${callback_decl}
|
| 1939 |
-
|
| 1940 |
-
// Gemm operator ${operation_name}
|
| 1941 |
-
using ${operation_name}_base =
|
| 1942 |
-
typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
| 1943 |
-
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 1944 |
-
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 1945 |
-
${element_c}, ${layout_c}, ${align_c},
|
| 1946 |
-
${element_accumulator},
|
| 1947 |
-
${element_epilogue},
|
| 1948 |
-
${opcode_class},
|
| 1949 |
-
${arch},
|
| 1950 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 1951 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 1952 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 1953 |
-
${callback_name},
|
| 1954 |
-
${swizzling_functor},
|
| 1955 |
-
${stages},
|
| 1956 |
-
${math_operation},
|
| 1957 |
-
${epilogue_stages} /* epilogue stages */
|
| 1958 |
-
>::GemmKernel;
|
| 1959 |
-
|
| 1960 |
-
// Define named type
|
| 1961 |
-
struct ${operation_name}${operation_suffix} :
|
| 1962 |
-
public ${operation_name}_base { };
|
| 1963 |
-
"""
|
| 1964 |
-
|
| 1965 |
-
def instance_template(self):
|
| 1966 |
-
return """
|
| 1967 |
-
${compile_guard_start}
|
| 1968 |
-
manifest.append(new ${gemm_kind}<
|
| 1969 |
-
cutlass::gemm::device::GemmUniversalAdapter<${operation_name}>
|
| 1970 |
-
>("${operation_name}"));
|
| 1971 |
-
${compile_guard_end}
|
| 1972 |
-
"""
|
| 1973 |
-
|
| 1974 |
-
def emit(self, operation):
|
| 1975 |
-
threadblock_shape = operation.tile_description.threadblock_shape
|
| 1976 |
-
warp_count = operation.tile_description.warp_count
|
| 1977 |
-
|
| 1978 |
-
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 1979 |
-
|
| 1980 |
-
instance_layout_A, instance_layout_B, instance_layout_C = \
|
| 1981 |
-
(operation.A.layout, operation.B.layout, operation.C.layout)
|
| 1982 |
-
|
| 1983 |
-
if operation.emission_type == EmissionType.Kernel:
|
| 1984 |
-
if self.direct_store:
|
| 1985 |
-
gemm_template = self.gemm_template_direct_store
|
| 1986 |
-
else:
|
| 1987 |
-
gemm_template = self.gemm_template_kernel
|
| 1988 |
-
else:
|
| 1989 |
-
gemm_template = self.gemm_template_device
|
| 1990 |
-
|
| 1991 |
-
values = {
|
| 1992 |
-
"operation_name": operation.procedural_name(),
|
| 1993 |
-
"operation_suffix": self.operation_suffix,
|
| 1994 |
-
"element_a": DataTypeTag[operation.A.element],
|
| 1995 |
-
"layout_a": LayoutTag[instance_layout_A],
|
| 1996 |
-
"element_b": DataTypeTag[operation.B.element],
|
| 1997 |
-
"layout_b": LayoutTag[instance_layout_B],
|
| 1998 |
-
"element_c": DataTypeTag[operation.C.element],
|
| 1999 |
-
"layout_c": LayoutTag[instance_layout_C],
|
| 2000 |
-
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 2001 |
-
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 2002 |
-
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
| 2003 |
-
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
|
| 2004 |
-
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
|
| 2005 |
-
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
|
| 2006 |
-
"warp_shape_m": str(warp_shape[0]),
|
| 2007 |
-
"warp_shape_n": str(warp_shape[1]),
|
| 2008 |
-
"warp_shape_k": str(warp_shape[2]),
|
| 2009 |
-
"instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 2010 |
-
"instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 2011 |
-
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 2012 |
-
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
|
| 2013 |
-
"stages": str(operation.tile_description.stages),
|
| 2014 |
-
"align_a": str(operation.A.alignment),
|
| 2015 |
-
"align_b": str(operation.B.alignment),
|
| 2016 |
-
"transform_a": ComplexTransformTag[operation.A.complex_transform],
|
| 2017 |
-
"transform_b": ComplexTransformTag[operation.B.complex_transform],
|
| 2018 |
-
"math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 2019 |
-
}
|
| 2020 |
-
|
| 2021 |
-
if hasattr(operation.epilogue_functor, "visitor"):
|
| 2022 |
-
self.includes += [
|
| 2023 |
-
"cutlass/epilogue/threadblock/fusion/visitors.hpp",
|
| 2024 |
-
"cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
| 2025 |
-
]
|
| 2026 |
-
callback_name, callback_decl = operation.epilogue_functor.emit(operation)
|
| 2027 |
-
values["callback_name"] = callback_name
|
| 2028 |
-
values["callback_decl"] = callback_decl
|
| 2029 |
-
values["align_c"] = str(operation.C.alignment)
|
| 2030 |
-
values["element_epilogue"] = DataTypeTag[operation.epilogue_functor.element_epilogue]
|
| 2031 |
-
if hasattr(operation.epilogue_functor, "epilogue_stages"):
|
| 2032 |
-
epilogue_stages = operation.epilogue_functor.epilogue_stages
|
| 2033 |
-
else:
|
| 2034 |
-
epilogue_stages = 1
|
| 2035 |
-
values["epilogue_stages"] = str(epilogue_stages)
|
| 2036 |
-
return SubstituteTemplate(self.gemm_template_kernel_visitor, values)
|
| 2037 |
-
else:
|
| 2038 |
-
values["epilogue_functor"] = operation.epilogue_functor.emit()
|
| 2039 |
-
return SubstituteTemplate(gemm_template, values)
|
| 2040 |
-
|
| 2041 |
-
|
| 2042 |
-
class EmitGemmGroupedInstance:
|
| 2043 |
-
"""Responsible for emitting a CUTLASS template definition"""
|
| 2044 |
-
|
| 2045 |
-
def __init__(self, operation_suffix=""):
|
| 2046 |
-
self.operation_suffix = operation_suffix
|
| 2047 |
-
self.includes = [
|
| 2048 |
-
"cutlass/cutlass.h",
|
| 2049 |
-
"cutlass/numeric_types.h",
|
| 2050 |
-
"cutlass/arch/arch.h",
|
| 2051 |
-
"cutlass/arch/mma.h",
|
| 2052 |
-
"cutlass/layout/matrix.h",
|
| 2053 |
-
"cutlass/gemm/kernel/gemm_grouped.h",
|
| 2054 |
-
"cutlass/gemm/kernel/default_gemm_grouped.h",
|
| 2055 |
-
]
|
| 2056 |
-
self.gemm_template_kernel = """
|
| 2057 |
-
// Gemm operator ${operation_name}
|
| 2058 |
-
using ${operation_name}_base =
|
| 2059 |
-
typename cutlass::gemm::kernel::DefaultGemmGrouped<
|
| 2060 |
-
${element_a}, ${layout_a}, ${transform_a}, ${align_a},
|
| 2061 |
-
${element_b}, ${layout_b}, ${transform_b}, ${align_b},
|
| 2062 |
-
${element_c}, ${layout_c},
|
| 2063 |
-
${element_accumulator},
|
| 2064 |
-
${opcode_class},
|
| 2065 |
-
${arch},
|
| 2066 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 2067 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 2068 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 2069 |
-
${epilogue_functor},
|
| 2070 |
-
${swizzling_functor},
|
| 2071 |
-
${stages},
|
| 2072 |
-
${precompute_mode},
|
| 2073 |
-
${math_operation}
|
| 2074 |
-
>::GemmKernel;
|
| 2075 |
-
|
| 2076 |
-
// Define named type
|
| 2077 |
-
struct ${operation_name}${operation_suffix} :
|
| 2078 |
-
public ${operation_name}_base { };
|
| 2079 |
-
"""
|
| 2080 |
-
self.gemm_template_device = (
|
| 2081 |
-
self.gemm_template_kernel
|
| 2082 |
-
+ """
|
| 2083 |
-
using DeviceKernel = cutlass::gemm::device::GemmGrouped<${operation_name}_base>;
|
| 2084 |
-
"""
|
| 2085 |
-
)
|
| 2086 |
-
|
| 2087 |
-
def instance_template(self):
|
| 2088 |
-
return """
|
| 2089 |
-
${compile_guard_start}
|
| 2090 |
-
manifest.append(new ${gemm_kind}<
|
| 2091 |
-
cutlass::gemm::device::GemmGrouped<${operation_name}>
|
| 2092 |
-
>("${operation_name}"));
|
| 2093 |
-
${compile_guard_end}
|
| 2094 |
-
"""
|
| 2095 |
-
|
| 2096 |
-
def emit(self, operation):
|
| 2097 |
-
threadblock_shape = operation.tile_description.threadblock_shape
|
| 2098 |
-
warp_count = operation.tile_description.warp_count
|
| 2099 |
-
|
| 2100 |
-
warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)]
|
| 2101 |
-
|
| 2102 |
-
instance_layout_A, instance_layout_B, instance_layout_C = \
|
| 2103 |
-
(operation.A.layout, operation.B.layout, operation.C.layout)
|
| 2104 |
-
|
| 2105 |
-
# Support built-in epilogue functors or user-defined functions
|
| 2106 |
-
epilogue_functor = operation.epilogue_functor.emit()
|
| 2107 |
-
|
| 2108 |
-
values = {
|
| 2109 |
-
"operation_name": operation.procedural_name(),
|
| 2110 |
-
"operation_suffix": self.operation_suffix,
|
| 2111 |
-
"element_a": DataTypeTag[operation.A.element],
|
| 2112 |
-
"layout_a": LayoutTag[instance_layout_A],
|
| 2113 |
-
"element_b": DataTypeTag[operation.B.element],
|
| 2114 |
-
"layout_b": LayoutTag[instance_layout_B],
|
| 2115 |
-
"element_c": DataTypeTag[operation.C.element],
|
| 2116 |
-
"layout_c": LayoutTag[instance_layout_C],
|
| 2117 |
-
"element_accumulator": DataTypeTag[operation.accumulator_type()],
|
| 2118 |
-
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 2119 |
-
"arch": "cutlass::arch::Sm%d" % operation.arch,
|
| 2120 |
-
"threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]),
|
| 2121 |
-
"threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]),
|
| 2122 |
-
"threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]),
|
| 2123 |
-
"warp_shape_m": str(warp_shape[0]),
|
| 2124 |
-
"warp_shape_n": str(warp_shape[1]),
|
| 2125 |
-
"warp_shape_k": str(warp_shape[2]),
|
| 2126 |
-
"instruction_shape_m": str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 2127 |
-
"instruction_shape_n": str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 2128 |
-
"instruction_shape_k": str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 2129 |
-
"epilogue_functor": epilogue_functor,
|
| 2130 |
-
"swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor],
|
| 2131 |
-
"stages": str(operation.tile_description.stages),
|
| 2132 |
-
"align_a": str(operation.A.alignment),
|
| 2133 |
-
"align_b": str(operation.B.alignment),
|
| 2134 |
-
"transform_a": ComplexTransformTag[operation.A.complex_transform],
|
| 2135 |
-
"transform_b": ComplexTransformTag[operation.B.complex_transform],
|
| 2136 |
-
"precompute_mode": SchedulerModeTag[operation.precompute_mode],
|
| 2137 |
-
"math_operation": MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 2138 |
-
}
|
| 2139 |
-
|
| 2140 |
-
if operation.emission_type == EmissionType.Kernel:
|
| 2141 |
-
gemm_template = self.gemm_template_kernel
|
| 2142 |
-
else:
|
| 2143 |
-
gemm_template = self.gemm_template_device
|
| 2144 |
-
|
| 2145 |
-
return SubstituteTemplate(gemm_template, values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/library.py
DELETED
|
@@ -1,509 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Common data types and string names/tags for them
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import enum
|
| 38 |
-
|
| 39 |
-
from cutlass_library import (
|
| 40 |
-
ComplexTransform,
|
| 41 |
-
DataType,
|
| 42 |
-
DataTypeSize,
|
| 43 |
-
EpilogueScheduleType,
|
| 44 |
-
KernelScheduleSuffixes,
|
| 45 |
-
KernelScheduleType,
|
| 46 |
-
MathOperation,
|
| 47 |
-
OpcodeClass,
|
| 48 |
-
TileSchedulerType
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
# The following block implements enum.auto() for Python 3.5 variants that don't include it such
|
| 53 |
-
# as the default 3.5.2 on Ubuntu 16.04.
|
| 54 |
-
#
|
| 55 |
-
# https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility
|
| 56 |
-
|
| 57 |
-
try:
|
| 58 |
-
from enum import auto as enum_auto
|
| 59 |
-
except ImportError:
|
| 60 |
-
__cutlass_library_auto_enum = 0
|
| 61 |
-
|
| 62 |
-
def enum_auto() -> int:
|
| 63 |
-
global __cutlass_library_auto_enum
|
| 64 |
-
i = __cutlass_library_auto_enum
|
| 65 |
-
__cutlass_library_auto_enum += 1
|
| 66 |
-
return i
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
class DataTypeSizeBytes:
|
| 70 |
-
"""
|
| 71 |
-
Static class to mimic the `DataTypeSize` dictionary, but with checks for whether the
|
| 72 |
-
data type key is less than a full byte or a non-integer number of bytes.
|
| 73 |
-
"""
|
| 74 |
-
|
| 75 |
-
@staticmethod
|
| 76 |
-
def __class_getitem__(datatype):
|
| 77 |
-
"""
|
| 78 |
-
Returns the number of bytes in size the data type is. Raises an exception if the data type
|
| 79 |
-
is either less than a full byte or a non-integer number of bytes in size.
|
| 80 |
-
|
| 81 |
-
:param datatype: data type to query
|
| 82 |
-
|
| 83 |
-
:return: number of bytes the data type occupies
|
| 84 |
-
:rtype: int
|
| 85 |
-
"""
|
| 86 |
-
bits = DataTypeSize[datatype]
|
| 87 |
-
if bits < 8:
|
| 88 |
-
raise Exception(
|
| 89 |
-
f"Data type {datatype} is less than one byte in size."
|
| 90 |
-
)
|
| 91 |
-
elif bits % 8 != 0:
|
| 92 |
-
raise Exception(
|
| 93 |
-
f"Data type datatype is not an integer number of bytes."
|
| 94 |
-
)
|
| 95 |
-
return bits // 8
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
class SchedulerMode(enum.Enum):
|
| 99 |
-
Device = enum_auto()
|
| 100 |
-
Host = enum_auto()
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
SchedulerModeTag = {
|
| 104 |
-
SchedulerMode.Device: "cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly",
|
| 105 |
-
SchedulerMode.Host: "cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute",
|
| 106 |
-
}
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
ShortSchedulerModeNames = {SchedulerMode.Device: "Device", SchedulerMode.Host: "Host"}
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class FunctionalOp(enum.Enum):
|
| 113 |
-
AtomicAdd = enum_auto()
|
| 114 |
-
AtomicMaximum = enum_auto()
|
| 115 |
-
Divides = enum_auto()
|
| 116 |
-
Maximum = enum_auto()
|
| 117 |
-
Minimum = enum_auto()
|
| 118 |
-
Minus = enum_auto()
|
| 119 |
-
Multiplies = enum_auto()
|
| 120 |
-
MultiplyAdd = enum_auto()
|
| 121 |
-
Plus = enum_auto()
|
| 122 |
-
Exp = enum_auto()
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
FunctionalOpTag = {
|
| 126 |
-
FunctionalOp.AtomicAdd: "cutlass::atomic_add",
|
| 127 |
-
FunctionalOp.AtomicMaximum: "cutlass::atomic_maximum",
|
| 128 |
-
FunctionalOp.Divides: "cutlass::divides",
|
| 129 |
-
FunctionalOp.Maximum: "cutlass::maximum",
|
| 130 |
-
FunctionalOp.Minimum: "cutlass::minimum",
|
| 131 |
-
FunctionalOp.Minus: "cutlass::minus",
|
| 132 |
-
FunctionalOp.Multiplies: "cutlass::multiplies",
|
| 133 |
-
FunctionalOp.MultiplyAdd: "cutlass::multiply_add",
|
| 134 |
-
FunctionalOp.Plus: "cutlass::plus",
|
| 135 |
-
FunctionalOp.Exp: "cutlass::fast_exp_op",
|
| 136 |
-
}
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
class ActivationOp(enum.Enum):
|
| 140 |
-
DGelu = enum_auto()
|
| 141 |
-
Gelu = enum_auto()
|
| 142 |
-
GeluTaylor = enum_auto()
|
| 143 |
-
HardSwish = enum_auto()
|
| 144 |
-
Identity = enum_auto()
|
| 145 |
-
LeakyReLU = enum_auto()
|
| 146 |
-
ReLU = enum_auto()
|
| 147 |
-
Sigmoid = enum_auto()
|
| 148 |
-
SiLU = enum_auto()
|
| 149 |
-
Tanh = enum_auto()
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
ActivationOpTag = {
|
| 153 |
-
ActivationOp.DGelu: "cutlass::epilogue::thread::dGELU",
|
| 154 |
-
ActivationOp.Gelu: "cutlass::epilogue::thread::GELU",
|
| 155 |
-
ActivationOp.GeluTaylor: "cutlass::epilogue::thread::GELU_taylor",
|
| 156 |
-
ActivationOp.HardSwish: "cutlass::epilogue::thread::HardSwish",
|
| 157 |
-
ActivationOp.Identity: "cutlass::epilogue::thread::Identity",
|
| 158 |
-
ActivationOp.LeakyReLU: "cutlass::epilogue::thread::LeakyReLU",
|
| 159 |
-
ActivationOp.ReLU: "cutlass::epilogue::thread::ReLu",
|
| 160 |
-
ActivationOp.Sigmoid: "cutlass::epilogue::thread::Sigmoid",
|
| 161 |
-
ActivationOp.SiLU: "cutlass::epilogue::thread::SiLu",
|
| 162 |
-
ActivationOp.Tanh: "cutlass::epilogue::thread::Tanh",
|
| 163 |
-
}
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def op_tag(op) -> str:
|
| 167 |
-
"""
|
| 168 |
-
Dispatches `op` to the appropriate *Tag dictionary depending on whether
|
| 169 |
-
`op` is an ActivationOp or FunctionalOp. This is useful for cases in which
|
| 170 |
-
either type can be used.
|
| 171 |
-
|
| 172 |
-
:param op: operation to emit a tag for
|
| 173 |
-
:type op: ActivationOp | FunctionalOp
|
| 174 |
-
|
| 175 |
-
:return: tag corresponding to op
|
| 176 |
-
:rtype: str
|
| 177 |
-
"""
|
| 178 |
-
if isinstance(op, ActivationOp):
|
| 179 |
-
return ActivationOpTag[op]
|
| 180 |
-
elif isinstance(op, FunctionalOp):
|
| 181 |
-
return FunctionalOpTag[op]
|
| 182 |
-
else:
|
| 183 |
-
raise Exception(f"Unexpected op type {op}. Must be one of ActivationOp or FunctionalOp.")
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
class FloatRoundStyle(enum.Enum):
|
| 187 |
-
ToNearest = enum_auto()
|
| 188 |
-
ToNearestSatfinite = enum_auto()
|
| 189 |
-
Indeterminate = enum_auto()
|
| 190 |
-
TowardZero = enum_auto()
|
| 191 |
-
TowardInfinity = enum_auto()
|
| 192 |
-
TowardNegInfinity = enum_auto()
|
| 193 |
-
HalfUlpTruncDntz = enum_auto()
|
| 194 |
-
HalfUlpTruncate = enum_auto()
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
FloatRoundStyleTag = {
|
| 198 |
-
FloatRoundStyle.ToNearest: "cutlass::FloatRoundStyle::round_to_nearest",
|
| 199 |
-
FloatRoundStyle.ToNearestSatfinite: "cutlass::FloatRoundStyle::round_to_nearest_satfinite",
|
| 200 |
-
FloatRoundStyle.Indeterminate: "cutlass::FloatRoundStyle::round_indeterminate",
|
| 201 |
-
FloatRoundStyle.TowardZero: "cutlass::FloatRoundStyle::round_toward_zero",
|
| 202 |
-
FloatRoundStyle.TowardInfinity: "cutlass::FloatRoundStyle::round_toward_infinity",
|
| 203 |
-
FloatRoundStyle.TowardNegInfinity: "cutlass::FloatRoundStyle::round_toward_neg_infinity",
|
| 204 |
-
FloatRoundStyle.HalfUlpTruncDntz: "cutlass::FloatRoundStyle::round_half_ulp_trunc_dntz",
|
| 205 |
-
FloatRoundStyle.HalfUlpTruncate: "cutlass::FloatRoundStyle::round_half_ulp_truncate",
|
| 206 |
-
}
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
class MathInstruction:
|
| 210 |
-
"""
|
| 211 |
-
Description of a the lowest-level matrix-multiply-accumulate operation to be used in a kernel
|
| 212 |
-
"""
|
| 213 |
-
|
| 214 |
-
def __init__(
|
| 215 |
-
self,
|
| 216 |
-
instruction_shape,
|
| 217 |
-
element_a,
|
| 218 |
-
element_b,
|
| 219 |
-
element_accumulator,
|
| 220 |
-
opcode_class=OpcodeClass.Simt,
|
| 221 |
-
math_operation=MathOperation.multiply_add,
|
| 222 |
-
):
|
| 223 |
-
"""
|
| 224 |
-
:param instruction_shape: size of the [M, N, K] dimensions of the instruction
|
| 225 |
-
:type instruction_shape: list or tuple
|
| 226 |
-
:param element_a: data type of operand A
|
| 227 |
-
:param element_b: data type of operand B
|
| 228 |
-
:param element_accumulator: data type used in accumulation
|
| 229 |
-
:param opcode_class: higher-level class of the instruction (e.g., SIMT or Tensor Core)
|
| 230 |
-
:type opcode_class: cutlass_library.library.OpcodeClass
|
| 231 |
-
:param math_operation: the type of low-level operation to be performed (e.g., multiply accumulate)
|
| 232 |
-
:type math_operation: MathOperation
|
| 233 |
-
"""
|
| 234 |
-
self.instruction_shape = instruction_shape
|
| 235 |
-
self.element_a = element_a
|
| 236 |
-
self.element_b = element_b
|
| 237 |
-
self.element_accumulator = element_accumulator
|
| 238 |
-
self.opcode_class = opcode_class
|
| 239 |
-
self.math_operation = math_operation
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def to_blackwell_threadblock_shape(tile_description, cluster_shape, kernel_schedule):
|
| 243 |
-
blackwell_threadblock_shape = tile_description.threadblock_shape
|
| 244 |
-
is_2sm = False if kernel_schedule is None else ("2sm" in KernelScheduleSuffixes[kernel_schedule])
|
| 245 |
-
if cluster_shape[0] > 0:
|
| 246 |
-
blackwell_threadblock_shape = [
|
| 247 |
-
tile_description.threadblock_shape[0] // cluster_shape[0],
|
| 248 |
-
tile_description.threadblock_shape[1] // cluster_shape[1],
|
| 249 |
-
tile_description.threadblock_shape[2] // cluster_shape[2]
|
| 250 |
-
]
|
| 251 |
-
if is_2sm:
|
| 252 |
-
blackwell_threadblock_shape[0] *= 2
|
| 253 |
-
else:
|
| 254 |
-
blackwell_threadblock_shape = tile_description.math_instruction.instruction_shape
|
| 255 |
-
return blackwell_threadblock_shape, is_2sm
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
class TileDescription:
|
| 259 |
-
"""
|
| 260 |
-
Description of a tile of computation to be performed in the kernel, encompassing threadblock, cluster, and warp shapes,
|
| 261 |
-
stage count, and math instruction specification
|
| 262 |
-
"""
|
| 263 |
-
|
| 264 |
-
def __init__(
|
| 265 |
-
self,
|
| 266 |
-
threadblock_shape,
|
| 267 |
-
stages,
|
| 268 |
-
warp_count,
|
| 269 |
-
math_instruction,
|
| 270 |
-
cluster_shape=[1, 1, 1],
|
| 271 |
-
kernel_schedule: KernelScheduleType = None,
|
| 272 |
-
epilogue_schedule: EpilogueScheduleType = None,
|
| 273 |
-
tile_scheduler: TileSchedulerType = None
|
| 274 |
-
):
|
| 275 |
-
"""
|
| 276 |
-
:param threadblock_shape: shape of a threadblock tyle
|
| 277 |
-
:type threadblock_shape: list or tuple
|
| 278 |
-
:param stages: number of pipline stages in the operation. For SM90 kernels, this can be set to `None` and the maximum
|
| 279 |
-
number of stages that can be supported for an operation on a given architecture will be computed at a later time
|
| 280 |
-
:type stages: int or None
|
| 281 |
-
:param warp_count: number of warps in each [M, N, K] dimension of a threadblock tile
|
| 282 |
-
:type warp_count: list, tuple, or None
|
| 283 |
-
:param math_instruction: specification of the instruction type and shape to be performed and the types of its operands
|
| 284 |
-
:type math_instruction: MathInstruction
|
| 285 |
-
:param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster
|
| 286 |
-
:param kernel_schedule: type of kernel schedule to use (only available for SM90+)
|
| 287 |
-
:type kernel_schedule: cutlass_library.KernelScheduleType
|
| 288 |
-
:param epilogue_schedule: type of epilogue schedule to use (only available for SM90+)
|
| 289 |
-
:type epilogue_schedule: cutlass_library.EpilogueScheduleType
|
| 290 |
-
:param tile_scheduler: type of tile scheduler to use (only available for SM90+)
|
| 291 |
-
:type tile_scheduler: cutlass_library.TileSchedulerType
|
| 292 |
-
"""
|
| 293 |
-
if ((kernel_schedule is None and epilogue_schedule is not None) or
|
| 294 |
-
(kernel_schedule is not None and epilogue_schedule is None)):
|
| 295 |
-
raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.")
|
| 296 |
-
|
| 297 |
-
self.threadblock_shape = threadblock_shape
|
| 298 |
-
self.cluster_shape = cluster_shape
|
| 299 |
-
self.kernel_schedule = kernel_schedule
|
| 300 |
-
self.epilogue_schedule = epilogue_schedule
|
| 301 |
-
self.tile_scheduler = tile_scheduler
|
| 302 |
-
self.stages = stages
|
| 303 |
-
|
| 304 |
-
self.math_instruction = math_instruction
|
| 305 |
-
self.instruction_shape = math_instruction.instruction_shape
|
| 306 |
-
|
| 307 |
-
# Number of warps along x, y, z directions
|
| 308 |
-
self.warp_count = warp_count
|
| 309 |
-
|
| 310 |
-
self.blackwell_threadblock_shape, self.is_2sm = to_blackwell_threadblock_shape(self, self.cluster_shape, self.kernel_schedule)
|
| 311 |
-
|
| 312 |
-
def clone_and_update(self, td: dict):
|
| 313 |
-
attrs = {
|
| 314 |
-
"cluster_shape": None,
|
| 315 |
-
"threadblock_shape": None,
|
| 316 |
-
"warp_count": None,
|
| 317 |
-
"stages": None,
|
| 318 |
-
"instruction_shape": None,
|
| 319 |
-
"kernel_schedule": None,
|
| 320 |
-
"epilogue_schedule": None,
|
| 321 |
-
"tile_scheduler": None
|
| 322 |
-
}
|
| 323 |
-
for key in attrs.keys():
|
| 324 |
-
if key in td.keys():
|
| 325 |
-
attrs[key] = td[key]
|
| 326 |
-
else:
|
| 327 |
-
attrs[key] = getattr(self, key)
|
| 328 |
-
|
| 329 |
-
attrs["math_instruction"] = MathInstruction(
|
| 330 |
-
attrs["instruction_shape"],
|
| 331 |
-
self.math_instruction.element_a,
|
| 332 |
-
self.math_instruction.element_b,
|
| 333 |
-
self.math_instruction.element_accumulator,
|
| 334 |
-
self.math_instruction.opcode_class,
|
| 335 |
-
self.math_instruction.math_operation
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
# Remove the instruction shape
|
| 339 |
-
del attrs["instruction_shape"]
|
| 340 |
-
|
| 341 |
-
return TileDescription(**attrs)
|
| 342 |
-
|
| 343 |
-
@property
|
| 344 |
-
def num_threads(self):
|
| 345 |
-
"""
|
| 346 |
-
Returns the number of threads in the threadblock
|
| 347 |
-
|
| 348 |
-
:return: number of threads in the threadblock
|
| 349 |
-
:rtype: int or None (if warp count is None)
|
| 350 |
-
"""
|
| 351 |
-
if self.warp_count is not None:
|
| 352 |
-
threads = 32
|
| 353 |
-
for cnt in self.warp_count:
|
| 354 |
-
threads *= cnt
|
| 355 |
-
return threads
|
| 356 |
-
return None
|
| 357 |
-
|
| 358 |
-
def procedural_name(self):
|
| 359 |
-
"""
|
| 360 |
-
Returns a name identifying the tile description
|
| 361 |
-
|
| 362 |
-
:return: name identifying the tile description
|
| 363 |
-
:rtype: int
|
| 364 |
-
"""
|
| 365 |
-
emit_stages = 0 if self.stages is None else self.stages
|
| 366 |
-
name = "%dx%dx%d_%dx%d_%dx%d" % (
|
| 367 |
-
self.cluster_shape[0],
|
| 368 |
-
self.cluster_shape[1],
|
| 369 |
-
self.cluster_shape[2],
|
| 370 |
-
self.threadblock_shape[0],
|
| 371 |
-
self.threadblock_shape[1],
|
| 372 |
-
self.threadblock_shape[2],
|
| 373 |
-
emit_stages
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
return name
|
| 377 |
-
|
| 378 |
-
def procedural_name_2x(self):
|
| 379 |
-
"""
|
| 380 |
-
Returns a name identifying the tile description
|
| 381 |
-
|
| 382 |
-
:return: name identifying the tile description
|
| 383 |
-
:rtype: int
|
| 384 |
-
"""
|
| 385 |
-
return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages)
|
| 386 |
-
|
| 387 |
-
def __str__(self):
|
| 388 |
-
"""
|
| 389 |
-
Returns a string with containing each of the tile description's values
|
| 390 |
-
|
| 391 |
-
:return: contents of tile description
|
| 392 |
-
:rtype: str
|
| 393 |
-
"""
|
| 394 |
-
if self.kernel_schedule is not None:
|
| 395 |
-
kschedule = self.kernel_schedule
|
| 396 |
-
else:
|
| 397 |
-
kschedule = KernelScheduleType.ScheduleAuto
|
| 398 |
-
|
| 399 |
-
if self.epilogue_schedule is not None:
|
| 400 |
-
eschedule = self.epilogue_schedule
|
| 401 |
-
else:
|
| 402 |
-
eschedule = EpilogueScheduleType.ScheduleAuto
|
| 403 |
-
|
| 404 |
-
if self.tile_scheduler is not None:
|
| 405 |
-
tschedule = self.tile_scheduler.name
|
| 406 |
-
else:
|
| 407 |
-
tschedule = "None"
|
| 408 |
-
return f"""
|
| 409 |
-
{{
|
| 410 |
-
ClusterShape: {self.cluster_shape}
|
| 411 |
-
ThreadblockShape: {self.threadblock_shape}
|
| 412 |
-
WarpCount: {self.warp_count}
|
| 413 |
-
Stages: {self.stages if self.stages is not None else 'Auto'}
|
| 414 |
-
InstructionShape: {self.math_instruction.instruction_shape}
|
| 415 |
-
Kernel schedule: {kschedule.name}
|
| 416 |
-
Epilogue schedule: {kschedule.name}
|
| 417 |
-
TileScheduler: {tschedule}
|
| 418 |
-
}}"""
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
class TensorDescription:
|
| 422 |
-
def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none):
|
| 423 |
-
self.element = element
|
| 424 |
-
self.layout = layout
|
| 425 |
-
if element != DataType.void:
|
| 426 |
-
self.alignment = min(128 // DataTypeSize[self.element], alignment)
|
| 427 |
-
else:
|
| 428 |
-
self.alignment = alignment
|
| 429 |
-
self.complex_transform = complex_transform
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
def CalculateSmemUsagePerStage(operation):
|
| 433 |
-
"""
|
| 434 |
-
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
|
| 435 |
-
|
| 436 |
-
:param op: operation for which the maximum stages should be computed. If stages are
|
| 437 |
-
set via the `op.tile_description.stages` parameter, this setting is ignored
|
| 438 |
-
in the present calculation
|
| 439 |
-
:type op: cutlass_cppgen.backend.Operation
|
| 440 |
-
|
| 441 |
-
:return: number of bytes of shared memory consumed by a single stage
|
| 442 |
-
:rtype: int
|
| 443 |
-
"""
|
| 444 |
-
m, n, k = operation.tile_description.threadblock_shape
|
| 445 |
-
|
| 446 |
-
if operation.operation_kind == OperationKind.Gemm:
|
| 447 |
-
stage_barrier_bytes = 32
|
| 448 |
-
return (
|
| 449 |
-
(DataTypeSize[operation.A.element] * m * k // 8)
|
| 450 |
-
+ (DataTypeSize[operation.B.element] * k * n // 8)
|
| 451 |
-
+ stage_barrier_bytes
|
| 452 |
-
)
|
| 453 |
-
else:
|
| 454 |
-
raise Exception("Unsupported operation kind {}.".format(operation.operation_kind))
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
def CalculateSmemUsage(operation):
|
| 458 |
-
"""
|
| 459 |
-
Returns the amount of shared memory in bytes consumed by a kernel.
|
| 460 |
-
|
| 461 |
-
:param op: operation for which the maximum stages should be computed. If stages are
|
| 462 |
-
set via the `op.tile_description.stages` parameter, this setting is ignored
|
| 463 |
-
in the present calculation
|
| 464 |
-
:type op: cutlass_cppgen.backend.Operation
|
| 465 |
-
|
| 466 |
-
:return: int
|
| 467 |
-
"""
|
| 468 |
-
return operation.tile_description.stages * CalculateSmemUsagePerStage(operation)
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
class ApiVersion(enum.Enum):
|
| 472 |
-
"""
|
| 473 |
-
Differentiate between CUTLASS 2.x and 3.x API versions
|
| 474 |
-
"""
|
| 475 |
-
|
| 476 |
-
v2x = enum_auto()
|
| 477 |
-
v3x = enum_auto()
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
def api_version(arch, opclass, dtype):
|
| 481 |
-
"""
|
| 482 |
-
Returns whether the architecture, opcode class, and datatype in question require using CUTLASS 2.x
|
| 483 |
-
or 3.x for code emission.
|
| 484 |
-
|
| 485 |
-
:param arch: compute capability of device on which to run
|
| 486 |
-
:type arch: int
|
| 487 |
-
:param opclass: class of the operation being performed
|
| 488 |
-
:type opclass: cutlass_library.OpcodeClass
|
| 489 |
-
:param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same)
|
| 490 |
-
:type dtype: cutlass_library.DataType
|
| 491 |
-
|
| 492 |
-
:return: API version to be used in code emission
|
| 493 |
-
:rtype: ApiVersion
|
| 494 |
-
"""
|
| 495 |
-
if (arch in [90, 100, 101, 103] and
|
| 496 |
-
opclass == OpcodeClass.TensorOp and
|
| 497 |
-
(dtype != DataType.f64)):
|
| 498 |
-
return ApiVersion.v3x
|
| 499 |
-
else:
|
| 500 |
-
return ApiVersion.v2x
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
class EmissionType(enum.Enum):
|
| 504 |
-
"""
|
| 505 |
-
Tags for whether to emit a kernel- or device-level operation
|
| 506 |
-
"""
|
| 507 |
-
|
| 508 |
-
Kernel = enum_auto()
|
| 509 |
-
Device = enum_auto()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/memory_manager.py
DELETED
|
@@ -1,121 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import numpy as np
|
| 34 |
-
|
| 35 |
-
import cutlass_cppgen
|
| 36 |
-
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
| 37 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 38 |
-
|
| 39 |
-
if cutlass_cppgen.use_rmm:
|
| 40 |
-
import rmm
|
| 41 |
-
else:
|
| 42 |
-
cudart = lazy_import("cuda.cudart")
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
class PoolMemoryManager:
|
| 46 |
-
def __init__(self, init_pool_size: int, max_pool_size: int) -> None:
|
| 47 |
-
self.pool = rmm.mr.PoolMemoryResource(
|
| 48 |
-
rmm.mr.CudaMemoryResource(),
|
| 49 |
-
initial_pool_size=init_pool_size,
|
| 50 |
-
maximum_pool_size=max_pool_size
|
| 51 |
-
)
|
| 52 |
-
self.mr = rmm.mr.TrackingResourceAdaptor(self.pool)
|
| 53 |
-
rmm.mr.set_current_device_resource(self.mr)
|
| 54 |
-
|
| 55 |
-
def pool_size(self):
|
| 56 |
-
return self.pool.pool_size()
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class DevicePtrWrapper:
|
| 60 |
-
"""
|
| 61 |
-
Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer
|
| 62 |
-
(at least in terms of the interface used by the CUTLASS Python interface)
|
| 63 |
-
"""
|
| 64 |
-
def __init__(self, dev_ptr):
|
| 65 |
-
self.dev_ptr = dev_ptr
|
| 66 |
-
|
| 67 |
-
@property
|
| 68 |
-
def ptr(self):
|
| 69 |
-
return self.dev_ptr
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def _todevice(host_data):
|
| 73 |
-
"""
|
| 74 |
-
Helper for transferring host data to device memory
|
| 75 |
-
"""
|
| 76 |
-
if cutlass_cppgen.use_rmm:
|
| 77 |
-
return rmm.DeviceBuffer.to_device(host_data.tobytes())
|
| 78 |
-
else:
|
| 79 |
-
nbytes = len(host_data.tobytes())
|
| 80 |
-
dev_ptr_wrapper = device_mem_alloc(nbytes)
|
| 81 |
-
err, = cudart.cudaMemcpy(
|
| 82 |
-
dev_ptr_wrapper.ptr,
|
| 83 |
-
host_data.__array_interface__['data'][0],
|
| 84 |
-
nbytes,
|
| 85 |
-
cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
|
| 86 |
-
)
|
| 87 |
-
if err != cudart.cudaError_t.cudaSuccess:
|
| 88 |
-
raise Exception(f"cudaMemcpy failed with error {err}")
|
| 89 |
-
return dev_ptr_wrapper
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def todevice(host_data, dtype=np.float32):
|
| 93 |
-
"""
|
| 94 |
-
Pass the host_data to device memory
|
| 95 |
-
"""
|
| 96 |
-
if isinstance(host_data, list):
|
| 97 |
-
return _todevice(np.array(host_data, dtype=dtype))
|
| 98 |
-
elif is_numpy_tensor(host_data):
|
| 99 |
-
return _todevice(host_data)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def device_mem_alloc(size):
|
| 103 |
-
if cutlass_cppgen.use_rmm:
|
| 104 |
-
return rmm.DeviceBuffer(size=size)
|
| 105 |
-
else:
|
| 106 |
-
err, ptr = cudart.cudaMalloc(size)
|
| 107 |
-
if err != cudart.cudaError_t.cudaSuccess:
|
| 108 |
-
raise Exception(f"cudaMalloc failed with error {err}")
|
| 109 |
-
return DevicePtrWrapper(ptr)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
def align_size(size, alignment=256):
|
| 113 |
-
return ((size + alignment - 1) // alignment) * alignment
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34):
|
| 117 |
-
if cutlass_cppgen.use_rmm:
|
| 118 |
-
memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size)
|
| 119 |
-
return memory_pool
|
| 120 |
-
else:
|
| 121 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/operation.py
DELETED
|
@@ -1,140 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import ctypes
|
| 34 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 35 |
-
cuda = lazy_import("cuda.cuda")
|
| 36 |
-
|
| 37 |
-
from cutlass_cppgen.backend.utils.device import device_cc
|
| 38 |
-
|
| 39 |
-
_supports_cluster_launch = None
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def supports_cluster_launch():
|
| 43 |
-
from cuda import __version__
|
| 44 |
-
_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")]
|
| 45 |
-
global _supports_cluster_launch
|
| 46 |
-
if _supports_cluster_launch is None:
|
| 47 |
-
major, minor = _version_splits[0], _version_splits[1]
|
| 48 |
-
_supports_cluster_launch = device_cc() in [90, 100, 101, 103] and (major > 11 or (major == 11 and minor >= 8))
|
| 49 |
-
return _supports_cluster_launch
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class LaunchConfiguration:
|
| 53 |
-
def __init__(self, grid=[1, 1, 1], block=[1, 1, 1], smem=0):
|
| 54 |
-
self.grid = grid
|
| 55 |
-
self.block = block
|
| 56 |
-
self.shared_memory_capacity = smem
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
class ExecutableOperation:
|
| 60 |
-
def __init__(self, operation):
|
| 61 |
-
self.operation = operation
|
| 62 |
-
self.module = None
|
| 63 |
-
self.kernel = None
|
| 64 |
-
|
| 65 |
-
def name(self):
|
| 66 |
-
return self.operation.procedural_name()
|
| 67 |
-
|
| 68 |
-
def emit(self):
|
| 69 |
-
return ""
|
| 70 |
-
|
| 71 |
-
def can_implement(self, configuration, arguments):
|
| 72 |
-
raise NotImplementedError()
|
| 73 |
-
|
| 74 |
-
def get_host_workspace_size(self, arguments):
|
| 75 |
-
raise NotImplementedError()
|
| 76 |
-
|
| 77 |
-
def get_device_workspace_size(self, arguments):
|
| 78 |
-
raise NotImplementedError()
|
| 79 |
-
|
| 80 |
-
def plan(self, arguments):
|
| 81 |
-
raise NotImplementedError()
|
| 82 |
-
|
| 83 |
-
def initialize(self, host_workspace, device_workspace, launch_config, arguments, stream=None):
|
| 84 |
-
raise NotImplementedError()
|
| 85 |
-
|
| 86 |
-
def run_with_clusters(self, launch_config, kernel_params, stream=None):
|
| 87 |
-
if not stream:
|
| 88 |
-
stream = cuda.CUstream(0)
|
| 89 |
-
if hasattr(self.operation, "tile_description") and hasattr(self.operation.tile_description, "cluster_shape"):
|
| 90 |
-
attr = cuda.CUlaunchAttribute()
|
| 91 |
-
attr.value.clusterDim.x, attr.value.clusterDim.y, attr.value.clusterDim.z = self.operation.tile_description.cluster_shape
|
| 92 |
-
attr.id = cuda.CUstreamAttrID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
|
| 93 |
-
attrs = [attr]
|
| 94 |
-
|
| 95 |
-
# Allow for non-portable cluster sizes
|
| 96 |
-
err, = cuda.cuFuncSetAttribute(
|
| 97 |
-
self.kernel, cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1)
|
| 98 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 99 |
-
return err
|
| 100 |
-
else:
|
| 101 |
-
attrs = []
|
| 102 |
-
|
| 103 |
-
config = cuda.CUlaunchConfig()
|
| 104 |
-
config.gridDimX, config.gridDimY, config.gridDimZ = launch_config.grid
|
| 105 |
-
config.blockDimX, config.blockDimY, config.blockDimZ = launch_config.block
|
| 106 |
-
config.blockDimZ = launch_config.block[2]
|
| 107 |
-
config.sharedMemBytes = launch_config.shared_memory_capacity
|
| 108 |
-
config.hStream = stream
|
| 109 |
-
config.attrs = attrs
|
| 110 |
-
config.numAttrs = len(attrs)
|
| 111 |
-
|
| 112 |
-
err, = cuda.cuLaunchKernelEx(
|
| 113 |
-
config, f=self.kernel, kernelParams=kernel_params, extra=0)
|
| 114 |
-
return err
|
| 115 |
-
|
| 116 |
-
def run_without_clusters(self, launch_config, kernel_params, stream=None):
|
| 117 |
-
if not stream:
|
| 118 |
-
stream = cuda.CUstream(0)
|
| 119 |
-
err, = cuda.cuLaunchKernel(
|
| 120 |
-
self.kernel,
|
| 121 |
-
launch_config.grid[0], launch_config.grid[1], launch_config.grid[2],
|
| 122 |
-
launch_config.block[0], launch_config.block[1], launch_config.block[2],
|
| 123 |
-
launch_config.shared_memory_capacity,
|
| 124 |
-
stream,
|
| 125 |
-
kernel_params,
|
| 126 |
-
0)
|
| 127 |
-
|
| 128 |
-
return err
|
| 129 |
-
|
| 130 |
-
def run(self, host_workspace, device_workspace, launch_config, stream=None):
|
| 131 |
-
if not stream:
|
| 132 |
-
stream = cuda.CUstream(0)
|
| 133 |
-
cArg = (ctypes.c_char * len(host_workspace)).from_buffer(host_workspace)
|
| 134 |
-
packed = (ctypes.c_void_p * 1)()
|
| 135 |
-
packed[0] = ctypes.addressof(cArg)
|
| 136 |
-
|
| 137 |
-
if supports_cluster_launch():
|
| 138 |
-
return self.run_with_clusters(launch_config, packed, stream)
|
| 139 |
-
else:
|
| 140 |
-
return self.run_without_clusters(launch_config, packed, stream)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/reduction_operation.py
DELETED
|
@@ -1,455 +0,0 @@
|
|
| 1 |
-
################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
################################################################################
|
| 32 |
-
from __future__ import annotations
|
| 33 |
-
|
| 34 |
-
import ctypes
|
| 35 |
-
from typing import Union
|
| 36 |
-
|
| 37 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 38 |
-
cuda = lazy_import("cuda.cuda")
|
| 39 |
-
cudart = lazy_import("cuda.cudart")
|
| 40 |
-
import numpy as np
|
| 41 |
-
|
| 42 |
-
from cutlass_library import (
|
| 43 |
-
DataTypeNames,
|
| 44 |
-
DataTypeSize,
|
| 45 |
-
DataTypeTag,
|
| 46 |
-
LayoutType,
|
| 47 |
-
SubstituteTemplate
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
import cutlass_cppgen
|
| 51 |
-
from cutlass_cppgen.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params
|
| 52 |
-
from cutlass_cppgen.backend.frontend import NumpyFrontend, TorchFrontend
|
| 53 |
-
from cutlass_cppgen.backend.library import TensorDescription
|
| 54 |
-
from cutlass_cppgen.backend.memory_manager import DevicePtrWrapper
|
| 55 |
-
from cutlass_cppgen.backend.operation import ExecutableOperation, LaunchConfiguration
|
| 56 |
-
from cutlass_cppgen.shape import MatrixCoord
|
| 57 |
-
from cutlass_cppgen.utils.datatypes import is_numpy_tensor, is_torch_tensor
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
class ReductionOperation:
|
| 61 |
-
pass
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
class ReductionArguments:
|
| 65 |
-
"""
|
| 66 |
-
Arguments of reduction
|
| 67 |
-
"""
|
| 68 |
-
|
| 69 |
-
def __init__(
|
| 70 |
-
self,
|
| 71 |
-
operation: ReductionOperation,
|
| 72 |
-
problem_size: "list[int]",
|
| 73 |
-
partitions: int,
|
| 74 |
-
workspace: cuda.CUdeviceptr,
|
| 75 |
-
destination: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
|
| 76 |
-
source: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor]",
|
| 77 |
-
**kwargs,
|
| 78 |
-
) -> None:
|
| 79 |
-
# tensor_C can be interpreted as the bias with bias=True in keyword args
|
| 80 |
-
if "bias" in kwargs.keys():
|
| 81 |
-
self.bias = kwargs["bias"]
|
| 82 |
-
else:
|
| 83 |
-
# by default, tensor_C is not bias
|
| 84 |
-
self.bias = False
|
| 85 |
-
if "stream" in kwargs.keys():
|
| 86 |
-
self.stream = kwargs["stream"]
|
| 87 |
-
else:
|
| 88 |
-
self.stream = cuda.CUstream(0)
|
| 89 |
-
|
| 90 |
-
self.operation = operation
|
| 91 |
-
self.ptr_workspace = workspace
|
| 92 |
-
|
| 93 |
-
# number of split-k partitions
|
| 94 |
-
self.partitions = partitions
|
| 95 |
-
|
| 96 |
-
if is_numpy_tensor(destination):
|
| 97 |
-
self.host_D = destination
|
| 98 |
-
self.destination_buffer = NumpyFrontend.argument(destination, True)
|
| 99 |
-
self.source_buffer = NumpyFrontend.argument(source, False)
|
| 100 |
-
self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr)
|
| 101 |
-
self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr)
|
| 102 |
-
elif is_torch_tensor(destination):
|
| 103 |
-
self.ptr_destination = TorchFrontend.argument(destination)
|
| 104 |
-
self.ptr_source = TorchFrontend.argument(source)
|
| 105 |
-
elif isinstance(destination, cuda.CUdeviceptr):
|
| 106 |
-
self.ptr_destination = destination
|
| 107 |
-
self.ptr_source = source
|
| 108 |
-
else:
|
| 109 |
-
raise TypeError("unknown Type")
|
| 110 |
-
|
| 111 |
-
self.problem_size = MatrixCoord_(problem_size[0], problem_size[1])
|
| 112 |
-
|
| 113 |
-
self.partition_stride = (
|
| 114 |
-
problem_size[0] * problem_size[1] * DataTypeSize[operation.C.element] // 8
|
| 115 |
-
)
|
| 116 |
-
|
| 117 |
-
if "output_op" in kwargs.keys():
|
| 118 |
-
self.output_op = kwargs["output_op"]
|
| 119 |
-
else:
|
| 120 |
-
self.output_op = self.operation.epilogue_type(1.0, 0.0)
|
| 121 |
-
|
| 122 |
-
self.get_arguments()
|
| 123 |
-
|
| 124 |
-
@staticmethod
|
| 125 |
-
def get_tensor_ref(
|
| 126 |
-
extent: "tuple[int]",
|
| 127 |
-
device_ptr: cuda.CUdeviceptr,
|
| 128 |
-
layout: LayoutType,
|
| 129 |
-
):
|
| 130 |
-
if layout == LayoutType.RowMajor:
|
| 131 |
-
return TensorRef2D_(int(device_ptr), extent[1])
|
| 132 |
-
else:
|
| 133 |
-
raise ValueError(f"Unknown layout type {layout}")
|
| 134 |
-
|
| 135 |
-
def get_arguments(self):
|
| 136 |
-
ref_workspace = ReductionArguments.get_tensor_ref(
|
| 137 |
-
extent=[
|
| 138 |
-
self.problem_size.row,
|
| 139 |
-
self.problem_size.column,
|
| 140 |
-
],
|
| 141 |
-
device_ptr=self.ptr_workspace,
|
| 142 |
-
layout=LayoutType.RowMajor,
|
| 143 |
-
)
|
| 144 |
-
if self.bias:
|
| 145 |
-
ref_source = ReductionArguments.get_tensor_ref(
|
| 146 |
-
extent=[0, 0],
|
| 147 |
-
device_ptr=self.ptr_source,
|
| 148 |
-
layout=LayoutType.RowMajor,
|
| 149 |
-
)
|
| 150 |
-
else:
|
| 151 |
-
ref_source = ReductionArguments.get_tensor_ref(
|
| 152 |
-
extent=[
|
| 153 |
-
self.problem_size.row,
|
| 154 |
-
self.problem_size.column,
|
| 155 |
-
],
|
| 156 |
-
device_ptr=self.ptr_source,
|
| 157 |
-
layout=LayoutType.RowMajor,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
ref_destination = ReductionArguments.get_tensor_ref(
|
| 161 |
-
extent=[
|
| 162 |
-
self.problem_size.row,
|
| 163 |
-
self.problem_size.column,
|
| 164 |
-
],
|
| 165 |
-
device_ptr=self.ptr_destination,
|
| 166 |
-
layout=LayoutType.RowMajor,
|
| 167 |
-
)
|
| 168 |
-
|
| 169 |
-
self.c_arguments = self.operation.argument_type(
|
| 170 |
-
self.problem_size,
|
| 171 |
-
self.partitions,
|
| 172 |
-
self.partition_stride,
|
| 173 |
-
ref_workspace,
|
| 174 |
-
ref_destination,
|
| 175 |
-
ref_source,
|
| 176 |
-
self.output_op,
|
| 177 |
-
)
|
| 178 |
-
|
| 179 |
-
params_ = self.operation.rt_module.get_args(ctypes.byref(self.c_arguments))
|
| 180 |
-
self.host_workspace = bytearray(params_.contents)
|
| 181 |
-
|
| 182 |
-
def sync(self):
|
| 183 |
-
(err,) = cudart.cudaDeviceSynchronize()
|
| 184 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 185 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 186 |
-
|
| 187 |
-
if hasattr(self, "host_D"):
|
| 188 |
-
(err,) = cuda.cuMemcpyDtoH(
|
| 189 |
-
self.host_D,
|
| 190 |
-
self.ptr_destination,
|
| 191 |
-
self.host_D.size * self.host_D.itemsize,
|
| 192 |
-
)
|
| 193 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 194 |
-
raise RuntimeError("CUDA Error %s" % str(err))
|
| 195 |
-
|
| 196 |
-
self.free()
|
| 197 |
-
|
| 198 |
-
def free(self):
|
| 199 |
-
"""
|
| 200 |
-
Frees allocated device-side memory
|
| 201 |
-
"""
|
| 202 |
-
# Free any device memory allocated manually
|
| 203 |
-
if not cutlass_cppgen.use_rmm:
|
| 204 |
-
for attr in ["destination_buffer", "source_buffer"]:
|
| 205 |
-
if hasattr(self, attr):
|
| 206 |
-
buf = getattr(self, attr)
|
| 207 |
-
if isinstance(buf, DevicePtrWrapper):
|
| 208 |
-
err, = cudart.cudaFree(buf.ptr)
|
| 209 |
-
if err != cudart.cudaError_t.cudaSuccess:
|
| 210 |
-
raise RuntimeError(f"cudaFree failed with error {err}")
|
| 211 |
-
del buf
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
class ReductionRT(ExecutableOperation):
|
| 215 |
-
"""
|
| 216 |
-
ReductionRT manages the CUTLASS runtime components for reduction
|
| 217 |
-
"""
|
| 218 |
-
|
| 219 |
-
KernelTemplate = r"""
|
| 220 |
-
extern "C"
|
| 221 |
-
__global__ void
|
| 222 |
-
${operation_name}(${operation_name}${operation_suffix}::Params params) {
|
| 223 |
-
|
| 224 |
-
// Dynamic shared memory base pointer
|
| 225 |
-
extern __shared__ int SharedStorageBase[];
|
| 226 |
-
|
| 227 |
-
// Declare pointer to dynamic shared memory.
|
| 228 |
-
${operation_name}${operation_suffix}::SharedStorage *shared_storage =
|
| 229 |
-
reinterpret_cast<${operation_name}${operation_suffix}::SharedStorage *>(SharedStorageBase);
|
| 230 |
-
|
| 231 |
-
${operation_name}${operation_suffix} op;
|
| 232 |
-
|
| 233 |
-
op(params, *shared_storage);
|
| 234 |
-
}
|
| 235 |
-
"""
|
| 236 |
-
HostTemplate = r"""
|
| 237 |
-
extern "C" {
|
| 238 |
-
// Get the size of params in bytes
|
| 239 |
-
int ${operation_name}_get_param_size(){
|
| 240 |
-
return sizeof(${operation_name}${operation_suffix}::Params);
|
| 241 |
-
}
|
| 242 |
-
|
| 243 |
-
// Get the size of dynamic shared memory in bytes
|
| 244 |
-
int ${operation_name}_shared_memory_size() {
|
| 245 |
-
return int(sizeof(${operation_name}${operation_suffix}::SharedStorage));
|
| 246 |
-
}
|
| 247 |
-
|
| 248 |
-
// Get the params as byte array
|
| 249 |
-
char* ${operation_name}_get_params(${operation_name}${operation_suffix}::Params* params){
|
| 250 |
-
char *bytes = ((char*)(params));
|
| 251 |
-
char *output = new char[sizeof(${operation_name}${operation_suffix}::Params)];
|
| 252 |
-
for (unsigned int i = 0; i < sizeof(${operation_name}${operation_suffix}::Params); i ++)
|
| 253 |
-
output[i] = bytes[i];
|
| 254 |
-
|
| 255 |
-
return output;
|
| 256 |
-
}
|
| 257 |
-
}
|
| 258 |
-
"""
|
| 259 |
-
|
| 260 |
-
def __init__(self, operation: ReductionOperation):
|
| 261 |
-
super().__init__(operation)
|
| 262 |
-
|
| 263 |
-
self.operation: ReductionOperation = operation
|
| 264 |
-
self.emitter = EmitReductionInstance("_type")
|
| 265 |
-
|
| 266 |
-
self.elements_per_access = self.operation.count
|
| 267 |
-
(
|
| 268 |
-
self.argument_type,
|
| 269 |
-
self.epilogue_type,
|
| 270 |
-
) = get_reduction_params(operation.epilogue_functor)
|
| 271 |
-
self.argtype = [ctypes.POINTER(self.argument_type)]
|
| 272 |
-
|
| 273 |
-
def emit(self):
|
| 274 |
-
return self.emitter.emit(self.operation)
|
| 275 |
-
|
| 276 |
-
def plan(self, arguments: ReductionArguments):
|
| 277 |
-
block_shape = [
|
| 278 |
-
self.operation.shape.column // self.elements_per_access,
|
| 279 |
-
self.operation.shape.row,
|
| 280 |
-
1,
|
| 281 |
-
]
|
| 282 |
-
grid_shape = [
|
| 283 |
-
(arguments.problem_size.row + self.operation.shape.row - 1)
|
| 284 |
-
// self.operation.shape.row,
|
| 285 |
-
(arguments.problem_size.column + self.operation.shape.column - 1)
|
| 286 |
-
// self.operation.shape.column,
|
| 287 |
-
1,
|
| 288 |
-
]
|
| 289 |
-
return LaunchConfiguration(
|
| 290 |
-
grid_shape,
|
| 291 |
-
block_shape,
|
| 292 |
-
self.shared_memory_capacity,
|
| 293 |
-
)
|
| 294 |
-
|
| 295 |
-
def initialize(self):
|
| 296 |
-
(err,) = cuda.cuFuncSetAttribute(
|
| 297 |
-
self.kernel,
|
| 298 |
-
attrib=cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
|
| 299 |
-
value=self.shared_memory_capacity,
|
| 300 |
-
)
|
| 301 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 302 |
-
raise RuntimeError(f"CUDA Error: {err}")
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
class ReductionOperation:
|
| 306 |
-
"""
|
| 307 |
-
CUTLASS reduction Operation
|
| 308 |
-
"""
|
| 309 |
-
|
| 310 |
-
def __init__(
|
| 311 |
-
self,
|
| 312 |
-
shape: MatrixCoord,
|
| 313 |
-
C: TensorDescription,
|
| 314 |
-
element_accumulator,
|
| 315 |
-
element_workspace=None,
|
| 316 |
-
element_compute=None,
|
| 317 |
-
epilogue_functor=None,
|
| 318 |
-
count: int = 1,
|
| 319 |
-
partitions_per_stage: int = 4,
|
| 320 |
-
) -> None:
|
| 321 |
-
self.shape = shape
|
| 322 |
-
self.epilogue_functor = epilogue_functor
|
| 323 |
-
self.element_accumulator = element_accumulator
|
| 324 |
-
|
| 325 |
-
if element_workspace is None:
|
| 326 |
-
self.element_workspace = element_accumulator
|
| 327 |
-
else:
|
| 328 |
-
self.element_workspace = element_workspace
|
| 329 |
-
|
| 330 |
-
if element_compute is None:
|
| 331 |
-
self.element_compute = element_accumulator
|
| 332 |
-
else:
|
| 333 |
-
self.element_compute = element_compute
|
| 334 |
-
|
| 335 |
-
self.element_output = C.element
|
| 336 |
-
self.C: TensorDescription = C
|
| 337 |
-
|
| 338 |
-
# Reduce op processing size
|
| 339 |
-
self.count: int = count
|
| 340 |
-
|
| 341 |
-
# Number of partitions to reduce per stage
|
| 342 |
-
self.partitions_per_stage: int = partitions_per_stage
|
| 343 |
-
|
| 344 |
-
self.rt_module: ReductionRT = ReductionRT(self)
|
| 345 |
-
self.argument_type = self.rt_module.argument_type
|
| 346 |
-
self.epilogue_type = self.rt_module.epilogue_type
|
| 347 |
-
|
| 348 |
-
def extended_name(self):
|
| 349 |
-
extend_name = "${element_workspace}_${element_accumulator}_${element_compute}_${element_output}"
|
| 350 |
-
|
| 351 |
-
return SubstituteTemplate(
|
| 352 |
-
extend_name,
|
| 353 |
-
{
|
| 354 |
-
"element_workspace": DataTypeNames[self.element_workspace],
|
| 355 |
-
"element_accumulator": DataTypeNames[self.element_accumulator],
|
| 356 |
-
"element_compute": DataTypeNames[self.element_compute],
|
| 357 |
-
"element_output": DataTypeNames[self.element_output],
|
| 358 |
-
},
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
def configuration_name(self):
|
| 362 |
-
"""The full procedural name indicates architecture, extended name, tile size"""
|
| 363 |
-
|
| 364 |
-
configuration_name = "cutlass_reduce_split_k_${extended_name}_${threadblock}"
|
| 365 |
-
|
| 366 |
-
threadblock = "%dx%d" % (
|
| 367 |
-
self.shape.row,
|
| 368 |
-
self.shape.column,
|
| 369 |
-
)
|
| 370 |
-
|
| 371 |
-
return SubstituteTemplate(
|
| 372 |
-
configuration_name,
|
| 373 |
-
{
|
| 374 |
-
"extended_name": self.extended_name(),
|
| 375 |
-
"threadblock": threadblock,
|
| 376 |
-
},
|
| 377 |
-
)
|
| 378 |
-
|
| 379 |
-
def procedural_name(self):
|
| 380 |
-
"""The full procedural name indicates architeture, extended name, tile size"""
|
| 381 |
-
return self.configuration_name()
|
| 382 |
-
|
| 383 |
-
def run(self, arguments: ReductionArguments) -> cuda.CUresult:
|
| 384 |
-
"""
|
| 385 |
-
Configure and launch the cuda kernel with input arguments
|
| 386 |
-
"""
|
| 387 |
-
launch_config = self.rt_module.plan(arguments)
|
| 388 |
-
|
| 389 |
-
host_workspace = arguments.host_workspace
|
| 390 |
-
device_workspace = None
|
| 391 |
-
|
| 392 |
-
err = self.rt_module.run(
|
| 393 |
-
host_workspace,
|
| 394 |
-
device_workspace,
|
| 395 |
-
launch_config,
|
| 396 |
-
arguments.stream
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 400 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 401 |
-
|
| 402 |
-
return err
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
class EmitReductionInstance:
|
| 406 |
-
def __init__(self, operation_suffix="") -> None:
|
| 407 |
-
self.operation_suffix = operation_suffix
|
| 408 |
-
self.includes = [
|
| 409 |
-
"cutlass/cutlass.h",
|
| 410 |
-
"cutlass/numeric_types.h",
|
| 411 |
-
"cutlass/arch/arch.h",
|
| 412 |
-
"cutlass/arch/mma.h",
|
| 413 |
-
"cutlass/layout/matrix.h",
|
| 414 |
-
"cutlass/gemm/device/gemm.h",
|
| 415 |
-
"cutlass/gemm/device/gemm_universal_adapter.h",
|
| 416 |
-
"cutlass/gemm/kernel/default_gemm_universal.h",
|
| 417 |
-
"cutlass/reduction/kernel/reduce_split_k.h",
|
| 418 |
-
"cutlass/reduction/thread/reduction_operators.h",
|
| 419 |
-
]
|
| 420 |
-
self.template = """
|
| 421 |
-
// Reduction kernel instance
|
| 422 |
-
using ${operation_name}_base =
|
| 423 |
-
typename cutlass::reduction::kernel::ReduceSplitK<
|
| 424 |
-
cutlass::MatrixShape<${shape_row}, ${shape_column}>,
|
| 425 |
-
${epilogue_functor},
|
| 426 |
-
cutlass::reduction::thread::ReduceAdd<
|
| 427 |
-
${element_accumulator},
|
| 428 |
-
${element_output},
|
| 429 |
-
${count}>,
|
| 430 |
-
${partition_per_stage}>;
|
| 431 |
-
|
| 432 |
-
struct ${operation_name}${operation_suffix}:
|
| 433 |
-
public ${operation_name}_base { };
|
| 434 |
-
"""
|
| 435 |
-
|
| 436 |
-
def emit(self, operation: ReductionOperation):
|
| 437 |
-
vector_length_bits = min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
|
| 438 |
-
epilogue_vector_length = vector_length_bits // DataTypeSize[operation.C.element]
|
| 439 |
-
|
| 440 |
-
values = {
|
| 441 |
-
"operation_name": operation.configuration_name(),
|
| 442 |
-
"operation_suffix": self.operation_suffix,
|
| 443 |
-
"shape_row": str(operation.shape.row),
|
| 444 |
-
"shape_column": str(operation.shape.column),
|
| 445 |
-
"epilogue_functor": operation.epilogue_functor.emit(),
|
| 446 |
-
"element_output": DataTypeTag[operation.element_output],
|
| 447 |
-
"epilogue_vector_length": str(epilogue_vector_length),
|
| 448 |
-
"element_accumulator": DataTypeTag[operation.element_accumulator],
|
| 449 |
-
"element_compute": DataTypeTag[operation.element_compute],
|
| 450 |
-
"element_workspace": DataTypeTag[operation.element_workspace],
|
| 451 |
-
"count": str(operation.count),
|
| 452 |
-
"partition_per_stage": str(operation.partitions_per_stage),
|
| 453 |
-
}
|
| 454 |
-
|
| 455 |
-
return SubstituteTemplate(self.template, values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/type_hint.py
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
################################################################################
|
| 32 |
-
|
| 33 |
-
GemmOperation = "Union[GemmOperationUniversal, GemmOperationGrouped]"
|
| 34 |
-
|
| 35 |
-
Tensor = "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/__init__.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
################################################################################
|
| 32 |
-
|
| 33 |
-
from cutlass_cppgen.backend.utils.device import check_cuda_errors, device_cc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/backend/utils/device.py
DELETED
|
@@ -1,126 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utility functions for interacting with the device
|
| 35 |
-
"""
|
| 36 |
-
from __future__ import annotations
|
| 37 |
-
|
| 38 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 39 |
-
cuda = lazy_import("cuda.cuda")
|
| 40 |
-
cudart = lazy_import("cuda.cudart")
|
| 41 |
-
|
| 42 |
-
import cutlass_cppgen
|
| 43 |
-
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def check_cuda_errors(result: list):
|
| 47 |
-
"""
|
| 48 |
-
Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise,
|
| 49 |
-
returns the result contained in the remaining fields of `result`.
|
| 50 |
-
|
| 51 |
-
:param result: the results of the `cudart` method, consisting of an error code and any method results
|
| 52 |
-
:type result: list
|
| 53 |
-
|
| 54 |
-
:return: non-error-code results from the `results` parameter
|
| 55 |
-
"""
|
| 56 |
-
# `result` is of the format : (cudaError_t, result...)
|
| 57 |
-
err = result[0]
|
| 58 |
-
if err.value:
|
| 59 |
-
raise RuntimeError("CUDA error: {}".format(cudart.cudaGetErrorName(err)))
|
| 60 |
-
|
| 61 |
-
if len(result) == 1:
|
| 62 |
-
return None
|
| 63 |
-
elif len(result) == 2:
|
| 64 |
-
return result[1]
|
| 65 |
-
else:
|
| 66 |
-
return result[1:]
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def device_cc(device: int = -1) -> int:
|
| 70 |
-
"""
|
| 71 |
-
Returns the compute capability of the device with ID `device`.
|
| 72 |
-
|
| 73 |
-
:param device: ID of the device to query
|
| 74 |
-
:type device: int
|
| 75 |
-
|
| 76 |
-
:return: compute capability of the queried device (e.g., 80 for SM80)
|
| 77 |
-
:rtype: int
|
| 78 |
-
"""
|
| 79 |
-
if device == -1:
|
| 80 |
-
device = cutlass_cppgen.device_id()
|
| 81 |
-
|
| 82 |
-
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
|
| 83 |
-
major = str(deviceProp.major)
|
| 84 |
-
minor = str(deviceProp.minor)
|
| 85 |
-
return int(major + minor)
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
def device_sm_count(device: int = -1):
|
| 89 |
-
if device == -1:
|
| 90 |
-
device = cutlass_cppgen.device_id()
|
| 91 |
-
err, device_sm_count = cuda.cuDeviceGetAttribute(
|
| 92 |
-
cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device
|
| 93 |
-
)
|
| 94 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 95 |
-
raise Exception(
|
| 96 |
-
"Failed to retireve SM count. "
|
| 97 |
-
f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}"
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
return device_sm_count
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def to_device_ptr(tensor) -> cuda.CUdeviceptr:
|
| 104 |
-
"""
|
| 105 |
-
Converts a tensor to a CUdeviceptr
|
| 106 |
-
|
| 107 |
-
:param tensor: tensor to convert
|
| 108 |
-
:type tensor: np.ndarray | torch.Tensor | cp.ndarray | int
|
| 109 |
-
|
| 110 |
-
:return: device pointer
|
| 111 |
-
:rtype: cuda.CUdeviceptr
|
| 112 |
-
"""
|
| 113 |
-
if is_numpy_tensor(tensor):
|
| 114 |
-
ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0])
|
| 115 |
-
elif is_torch_tensor(tensor):
|
| 116 |
-
ptr = cuda.CUdeviceptr(tensor.data_ptr())
|
| 117 |
-
elif is_cupy_tensor(tensor):
|
| 118 |
-
ptr = cuda.CUdeviceptr(int(tensor.data.ptr))
|
| 119 |
-
elif isinstance(tensor, cuda.CUdeviceptr):
|
| 120 |
-
ptr = tensor
|
| 121 |
-
elif isinstance(tensor, int):
|
| 122 |
-
ptr = cuda.CUdeviceptr(tensor)
|
| 123 |
-
else:
|
| 124 |
-
raise NotImplementedError(tensor)
|
| 125 |
-
|
| 126 |
-
return ptr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/__init__.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from cutlass_cppgen.emit.pytorch import pytorch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/common.py
DELETED
|
@@ -1,267 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Common utilities for emitting CUTLASS kernels
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import cutlass_cppgen
|
| 38 |
-
|
| 39 |
-
# Strings used for printing information about the generation of emitted scripts
|
| 40 |
-
_AUTOGEN_STR = f"This file was automatically generated by the CUTLASS {cutlass_cppgen.__version__} Python interface (https://github.com/nvidia/cutlass/python)"
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
_CSTYLE_AUTOGEN_COMMENT = f"""// {_AUTOGEN_STR}
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
_PYSTYLE_AUTOGEN_COMMENT = f"""# {_AUTOGEN_STR}
|
| 48 |
-
"""
|
| 49 |
-
|
| 50 |
-
_CUTLASS_KERNEL_ARGS_2x = """
|
| 51 |
-
typename DeviceKernel::Arguments arguments {
|
| 52 |
-
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 53 |
-
{M, N, K}, // problem size
|
| 54 |
-
1,
|
| 55 |
-
{alpha, beta},
|
| 56 |
-
A, B, C, D,
|
| 57 |
-
0, 0, 0, 0, // batch strides
|
| 58 |
-
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
| 59 |
-
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
| 60 |
-
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
| 61 |
-
DeviceKernel::LayoutC::packed({M, N}).stride(0) // ldd
|
| 62 |
-
};
|
| 63 |
-
"""
|
| 64 |
-
|
| 65 |
-
_CUTLASS_KERNEL_ARGS_2x_STREAM_K = """
|
| 66 |
-
typename DeviceKernel::Arguments arguments {
|
| 67 |
-
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 68 |
-
{M, N, K}, // problem size
|
| 69 |
-
1,
|
| 70 |
-
{alpha, beta},
|
| 71 |
-
A, B, C, D,
|
| 72 |
-
0, 0, 0, 0, // batch strides
|
| 73 |
-
DeviceKernel::LayoutA::packed({M, K}).stride(0), // lda
|
| 74 |
-
DeviceKernel::LayoutB::packed({K, N}).stride(0), // ldb
|
| 75 |
-
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldc
|
| 76 |
-
DeviceKernel::LayoutC::packed({M, N}).stride(0), // ldd
|
| 77 |
-
-1 // avail_sms
|
| 78 |
-
};
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
_CUTLASS_KERNEL_RUN_GEMM_2x = """
|
| 82 |
-
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
| 83 |
-
|
| 84 |
-
cutlass::Status ${name}_kernel_run(int M, int N, int K,
|
| 85 |
-
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
| 86 |
-
ElementCompute alpha, ElementCompute beta) {
|
| 87 |
-
${args}
|
| 88 |
-
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
| 89 |
-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 90 |
-
|
| 91 |
-
DeviceKernel gemm_op;
|
| 92 |
-
cutlass::Status status = gemm_op.initialize(arguments,
|
| 93 |
-
workspace.get(),
|
| 94 |
-
nullptr); // CUDA stream
|
| 95 |
-
|
| 96 |
-
if (status != cutlass::Status::kSuccess) {
|
| 97 |
-
return status;
|
| 98 |
-
}
|
| 99 |
-
|
| 100 |
-
status = gemm_op();
|
| 101 |
-
return status;
|
| 102 |
-
}
|
| 103 |
-
"""
|
| 104 |
-
|
| 105 |
-
_CUTLASS_KERNEL_RUN_GEMM_3x = """
|
| 106 |
-
using StrideA = typename DeviceKernel::GemmKernel::StrideA;
|
| 107 |
-
using StrideB = typename DeviceKernel::GemmKernel::StrideB;
|
| 108 |
-
using StrideC = typename DeviceKernel::GemmKernel::StrideC;
|
| 109 |
-
using StrideD = typename DeviceKernel::GemmKernel::StrideD;
|
| 110 |
-
|
| 111 |
-
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
| 112 |
-
|
| 113 |
-
cutlass::Status ${name}_kernel_run(
|
| 114 |
-
int M, int N, int K, int L,
|
| 115 |
-
const DeviceKernel::ElementA* A, const DeviceKernel::ElementB* B, const DeviceKernel::ElementC* C, DeviceKernel::ElementC* D,
|
| 116 |
-
ElementCompute alpha, ElementCompute beta, const cutlass::KernelHardwareInfo& hw_info) {
|
| 117 |
-
|
| 118 |
-
typename DeviceKernel::Arguments arguments{
|
| 119 |
-
cutlass::gemm::GemmUniversalMode::kGemm,
|
| 120 |
-
{M, N, K, L}, // problem size
|
| 121 |
-
{
|
| 122 |
-
A, // ptrA
|
| 123 |
-
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
| 124 |
-
B, // ptrB
|
| 125 |
-
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
| 126 |
-
},
|
| 127 |
-
{
|
| 128 |
-
{alpha, beta},
|
| 129 |
-
C, // ptrC
|
| 130 |
-
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
| 131 |
-
D, // ptrD
|
| 132 |
-
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
| 133 |
-
},
|
| 134 |
-
hw_info
|
| 135 |
-
};
|
| 136 |
-
|
| 137 |
-
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
| 138 |
-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 139 |
-
|
| 140 |
-
DeviceKernel gemm_op;
|
| 141 |
-
cutlass::Status status = gemm_op.run(arguments,
|
| 142 |
-
workspace.get(),
|
| 143 |
-
nullptr); // CUDA stream
|
| 144 |
-
|
| 145 |
-
return status;
|
| 146 |
-
}
|
| 147 |
-
"""
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
_CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x = """
|
| 151 |
-
using ElementCompute = typename DeviceKernel::EpilogueOutputOp::ElementCompute;
|
| 152 |
-
|
| 153 |
-
int threadblock_count = DeviceKernel::sufficient();
|
| 154 |
-
|
| 155 |
-
cutlass::Status ${name}_kernel_run(int problem_count, cutlass::gemm::GemmCoord* problem_sizes,
|
| 156 |
-
DeviceKernel::ElementA** A, DeviceKernel::ElementB** B, DeviceKernel::ElementC** C, DeviceKernel::ElementC** D,
|
| 157 |
-
int64_t* lda, int64_t* ldb, int64_t* ldc, int64_t* ldd,
|
| 158 |
-
ElementCompute alpha, ElementCompute beta) {
|
| 159 |
-
|
| 160 |
-
typename DeviceKernel::Arguments arguments {
|
| 161 |
-
problem_sizes,
|
| 162 |
-
problem_count,
|
| 163 |
-
threadblock_count,
|
| 164 |
-
{alpha, beta},
|
| 165 |
-
A, B, C, D,
|
| 166 |
-
lda, ldb, ldc, ldd
|
| 167 |
-
};
|
| 168 |
-
|
| 169 |
-
size_t workspace_size = DeviceKernel::get_workspace_size(arguments);
|
| 170 |
-
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 171 |
-
|
| 172 |
-
DeviceKernel gemm_op;
|
| 173 |
-
cutlass::Status status = gemm_op.initialize(arguments,
|
| 174 |
-
workspace.get(),
|
| 175 |
-
nullptr); // CUDA stream
|
| 176 |
-
|
| 177 |
-
if (status != cutlass::Status::kSuccess) {
|
| 178 |
-
return status;
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
status = gemm_op();
|
| 182 |
-
return status;
|
| 183 |
-
}
|
| 184 |
-
"""
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
_CUTLASS_KERNEL_RUN_CONV2D_2x = """
|
| 188 |
-
|
| 189 |
-
using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel;
|
| 190 |
-
namespace {
|
| 191 |
-
using TensorRefA = typename UnderlyingKernel::TensorRefA;
|
| 192 |
-
using TensorRefB = typename UnderlyingKernel::TensorRefB;
|
| 193 |
-
using TensorRefC = typename UnderlyingKernel::TensorRefC;
|
| 194 |
-
using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute;
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
template<typename TensorRef, typename Element>
|
| 198 |
-
TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){
|
| 199 |
-
cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord);
|
| 200 |
-
TensorRef tensor_ref(ptr, layout);
|
| 201 |
-
return tensor_ref;
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size,
|
| 205 |
-
UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B,
|
| 206 |
-
UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D,
|
| 207 |
-
ElementCompute alpha, ElementCompute beta, std::string split_k_mode,
|
| 208 |
-
cudaStream_t stream, int device_id=0) {
|
| 209 |
-
// create the tensor references
|
| 210 |
-
cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent(
|
| 211 |
-
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
| 212 |
-
);
|
| 213 |
-
cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent(
|
| 214 |
-
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
| 215 |
-
);
|
| 216 |
-
cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent(
|
| 217 |
-
cutlass::conv::Operator::k${conv_kind_name}, *problem_size
|
| 218 |
-
);
|
| 219 |
-
|
| 220 |
-
TensorRefA tensor_ref_A = get_tensor_ref<TensorRefA, UnderlyingKernel::ElementA>(tensor_coord_A, A);
|
| 221 |
-
TensorRefB tensor_ref_B = get_tensor_ref<TensorRefB, UnderlyingKernel::ElementB>(tensor_coord_B, B);
|
| 222 |
-
TensorRefC tensor_ref_C = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, C);
|
| 223 |
-
TensorRefC tensor_ref_D = get_tensor_ref<TensorRefC, UnderlyingKernel::ElementC>(tensor_coord_C, D);
|
| 224 |
-
|
| 225 |
-
cutlass::conv::SplitKMode mode;
|
| 226 |
-
if (split_k_mode == "serial") {
|
| 227 |
-
mode = cutlass::conv::SplitKMode::kSerial;
|
| 228 |
-
} else if (split_k_mode == "parallel") {
|
| 229 |
-
mode = cutlass::conv::SplitKMode::kParallel;
|
| 230 |
-
} else {
|
| 231 |
-
throw std::runtime_error("Invalid split_k_mode: " + split_k_mode);
|
| 232 |
-
}
|
| 233 |
-
|
| 234 |
-
typename DeviceKernel::Arguments arguments{
|
| 235 |
-
*problem_size,
|
| 236 |
-
tensor_ref_A,
|
| 237 |
-
tensor_ref_B,
|
| 238 |
-
tensor_ref_C,
|
| 239 |
-
tensor_ref_D,
|
| 240 |
-
{alpha, beta},
|
| 241 |
-
mode
|
| 242 |
-
};
|
| 243 |
-
|
| 244 |
-
DeviceKernel implicit_gemm_op;
|
| 245 |
-
|
| 246 |
-
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
|
| 247 |
-
|
| 248 |
-
void* workspace_ptr = device_memory_allocation(workspace_size, device_id);
|
| 249 |
-
|
| 250 |
-
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
|
| 251 |
-
if (status != cutlass::Status::kSuccess) {
|
| 252 |
-
return status;
|
| 253 |
-
}
|
| 254 |
-
|
| 255 |
-
status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream);
|
| 256 |
-
if (status != cutlass::Status::kSuccess) {
|
| 257 |
-
return status;
|
| 258 |
-
}
|
| 259 |
-
|
| 260 |
-
//
|
| 261 |
-
// Launch initialized CUTLASS kernel
|
| 262 |
-
//
|
| 263 |
-
status = implicit_gemm_op(stream);
|
| 264 |
-
|
| 265 |
-
return status;
|
| 266 |
-
}
|
| 267 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/emit/pytorch.py
DELETED
|
@@ -1,936 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utilities for generating source for building a PyTorch CUDA extension that using a CUTLASS kernel.
|
| 35 |
-
If specified, the extension can be JIT compiled via PyTorch's ``cpp_extension.load`` method.
|
| 36 |
-
|
| 37 |
-
Example usage with JIT compilation:
|
| 38 |
-
|
| 39 |
-
.. highlight:: python
|
| 40 |
-
.. code-block:: python
|
| 41 |
-
|
| 42 |
-
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor)
|
| 43 |
-
op = plan.construct()
|
| 44 |
-
mod = cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=True)
|
| 45 |
-
|
| 46 |
-
# Generate inputs for the GEMM
|
| 47 |
-
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
| 48 |
-
|
| 49 |
-
# Run the module
|
| 50 |
-
D = mod.run(A, B, C)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
Example usage without JIT compilation:
|
| 54 |
-
|
| 55 |
-
.. highlight:: python
|
| 56 |
-
.. code-block:: python
|
| 57 |
-
|
| 58 |
-
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 59 |
-
op = plan.construct()
|
| 60 |
-
cutlass_cppgen.emit.pytorch(op, 'cutlass_gemm', 80, jit=False, sourcedir='output')
|
| 61 |
-
|
| 62 |
-
After this call, the directory ``output`` contains ``setup.py``,
|
| 63 |
-
``cutlass_gemm.cpp``, and ``cutlass_gemm_kernel.cu``. The module can be built from
|
| 64 |
-
within ``output`` by running: ``TORCH_CUDA_ARCH_LIST="8.0" python setup.py develop --user``.
|
| 65 |
-
|
| 66 |
-
The module can later be used in Python via:
|
| 67 |
-
|
| 68 |
-
.. highlight:: python
|
| 69 |
-
.. code-block:: python
|
| 70 |
-
|
| 71 |
-
import torch
|
| 72 |
-
import cutlass_gemm
|
| 73 |
-
|
| 74 |
-
# Generate inputs for the GEMM
|
| 75 |
-
A, B, C = [torch.ones((512, 512)).to('cuda') for _ in range(3)]
|
| 76 |
-
|
| 77 |
-
# Run the module
|
| 78 |
-
D = cutlass_gemm.run(A, B, C)
|
| 79 |
-
"""
|
| 80 |
-
|
| 81 |
-
import logging
|
| 82 |
-
import os
|
| 83 |
-
|
| 84 |
-
from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate
|
| 85 |
-
|
| 86 |
-
from cutlass_cppgen import CUTLASS_PATH, logger, swizzle
|
| 87 |
-
from cutlass_cppgen.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal
|
| 88 |
-
from cutlass_cppgen.backend.conv2d_operation import Conv2dOperation
|
| 89 |
-
from cutlass_cppgen.backend.library import ApiVersion
|
| 90 |
-
from cutlass_cppgen.emit import common
|
| 91 |
-
from cutlass_cppgen.utils.datatypes import is_torch_available
|
| 92 |
-
|
| 93 |
-
if is_torch_available():
|
| 94 |
-
import torch
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
_PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 98 |
-
#include <cuda_runtime.h>
|
| 99 |
-
#include <torch/extension.h>
|
| 100 |
-
#include <ATen/ATen.h>
|
| 101 |
-
#include <ATen/cuda/CUDAContext.h>
|
| 102 |
-
#include "cutlass/cutlass.h"
|
| 103 |
-
#include "cutlass/util/device_memory.h"
|
| 104 |
-
|
| 105 |
-
// helper function allocating the memory
|
| 106 |
-
void* device_memory_allocation(size_t size, int device_id=0) {
|
| 107 |
-
if (size > 0) {
|
| 108 |
-
torch::Device device(torch::kCUDA, device_id);
|
| 109 |
-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 110 |
-
torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device);
|
| 111 |
-
at::Tensor device_tensor = torch::empty({(long)size,}, options);
|
| 112 |
-
return reinterpret_cast<void*>(device_tensor.data_ptr());
|
| 113 |
-
} else {
|
| 114 |
-
return nullptr;
|
| 115 |
-
}
|
| 116 |
-
}
|
| 117 |
-
|
| 118 |
-
${includes}
|
| 119 |
-
${declaration}
|
| 120 |
-
${impl}
|
| 121 |
-
"""
|
| 122 |
-
|
| 123 |
-
_PYTORCH_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 124 |
-
#include <torch/extension.h>
|
| 125 |
-
#include <ATen/ATen.h>
|
| 126 |
-
#include <pybind11/stl.h>
|
| 127 |
-
|
| 128 |
-
// CUDA forward declarations
|
| 129 |
-
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f);
|
| 130 |
-
|
| 131 |
-
// C++ interface
|
| 132 |
-
at::Tensor ${name}(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt, float alpha=1.f, float beta=0.f) {
|
| 133 |
-
return ${name}_kernel(A, B, C, alpha, beta);
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 137 |
-
m.def("run", py::overload_cast<const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>, float, float>(&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
|
| 138 |
-
}
|
| 139 |
-
"""
|
| 140 |
-
|
| 141 |
-
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 142 |
-
#include <torch/extension.h>
|
| 143 |
-
#include <ATen/ATen.h>
|
| 144 |
-
#include <pybind11/stl.h>
|
| 145 |
-
|
| 146 |
-
// CUDA forward declarations
|
| 147 |
-
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f);
|
| 148 |
-
|
| 149 |
-
// C++ interface
|
| 150 |
-
std::vector<at::Tensor> ${name}(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C=at::nullopt, float alpha=1.f, float beta=0.f) {
|
| 151 |
-
return ${name}_kernel(A, B, C, alpha, beta);
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 155 |
-
m.def("run", py::overload_cast<const std::vector<at::Tensor>&, const std::vector<at::Tensor>&, at::optional<const std::vector<at::Tensor>>, float, float>(&${name}),
|
| 156 |
-
py::arg("A"), py::arg("B"), py::arg("C") = nullptr, py::arg("alpha") = 1.f, py::arg("beta") = 0.f);
|
| 157 |
-
}
|
| 158 |
-
"""
|
| 159 |
-
|
| 160 |
-
_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 161 |
-
#include <torch/extension.h>
|
| 162 |
-
#include <ATen/ATen.h>
|
| 163 |
-
#include <pybind11/stl.h>
|
| 164 |
-
|
| 165 |
-
// CUDA forward declarations
|
| 166 |
-
at::Tensor ${name}_kernel(
|
| 167 |
-
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 168 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 169 |
-
float alpha=1.f, float beta=0.f,
|
| 170 |
-
std::string split_k_mode="serial", int split_k_slices=1);
|
| 171 |
-
|
| 172 |
-
// C++ interface
|
| 173 |
-
at::Tensor ${name}(
|
| 174 |
-
const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 175 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 176 |
-
float alpha=1.f, float beta=0.f,
|
| 177 |
-
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 178 |
-
return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
| 179 |
-
}
|
| 180 |
-
|
| 181 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 182 |
-
m.def("run",
|
| 183 |
-
py::overload_cast<
|
| 184 |
-
const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
| 185 |
-
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
| 186 |
-
&${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
| 187 |
-
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
| 188 |
-
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
| 189 |
-
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
| 190 |
-
}
|
| 191 |
-
"""
|
| 192 |
-
|
| 193 |
-
_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """
|
| 194 |
-
#include <torch/extension.h>
|
| 195 |
-
#include <ATen/ATen.h>
|
| 196 |
-
#include <pybind11/stl.h>
|
| 197 |
-
|
| 198 |
-
// CUDA forward declarations
|
| 199 |
-
at::Tensor ${name}_kernel(
|
| 200 |
-
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 201 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 202 |
-
float alpha=1.f, float beta=0.f,
|
| 203 |
-
std::string split_k_mode="serial", int split_k_slices=1);
|
| 204 |
-
|
| 205 |
-
// C++ interface
|
| 206 |
-
at::Tensor ${name}(
|
| 207 |
-
std::tuple<int, int, int, int> result_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 208 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 209 |
-
float alpha=1.f, float beta=0.f,
|
| 210 |
-
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 211 |
-
return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices);
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 215 |
-
m.def("run",
|
| 216 |
-
py::overload_cast<
|
| 217 |
-
std::tuple<int, int, int, int>, const at::Tensor&, const at::Tensor&, at::optional<const at::Tensor>,
|
| 218 |
-
std::tuple<int, int>, std::tuple<int, int>, std::tuple<int, int>, float, float, std::string, int>(
|
| 219 |
-
&${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr,
|
| 220 |
-
py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1),
|
| 221 |
-
py::arg("alpha") = 1.f, py::arg("beta") = 0.f,
|
| 222 |
-
py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1);
|
| 223 |
-
}
|
| 224 |
-
"""
|
| 225 |
-
|
| 226 |
-
_PYTORCH_GEMM_INCLUDES = {
|
| 227 |
-
ApiVersion.v2x: """
|
| 228 |
-
#include "cutlass/gemm/device/gemm_universal.h"
|
| 229 |
-
""",
|
| 230 |
-
ApiVersion.v3x: """
|
| 231 |
-
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 232 |
-
#include "cutlass/gemm/collective/collective_builder.hpp"
|
| 233 |
-
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
| 234 |
-
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
| 235 |
-
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
| 236 |
-
#include "cutlass/util/packed_stride.hpp"
|
| 237 |
-
""",
|
| 238 |
-
}
|
| 239 |
-
|
| 240 |
-
_PYTORCH_GROUPED_GEMM_INCLUDES = """
|
| 241 |
-
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
|
| 242 |
-
#include "cutlass/gemm/device/gemm_grouped.h"
|
| 243 |
-
"""
|
| 244 |
-
|
| 245 |
-
_PYTORCH_CONV2D_INCLUDES = """
|
| 246 |
-
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
|
| 247 |
-
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
|
| 248 |
-
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
|
| 249 |
-
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 250 |
-
"""
|
| 251 |
-
|
| 252 |
-
_CUTLASS_TYPE_TO_TORCH_TYPE = {
|
| 253 |
-
DataType.f16: "torch::kF16",
|
| 254 |
-
DataType.f32: "torch::kF32",
|
| 255 |
-
DataType.f64: "torch::kF64",
|
| 256 |
-
DataType.s8: "torch::kI8",
|
| 257 |
-
DataType.s32: "torch::kI32",
|
| 258 |
-
DataType.bf16: "torch::kBFloat16",
|
| 259 |
-
}
|
| 260 |
-
|
| 261 |
-
_PYTORCH_GEMM_IMPL_TEMPLATE_2x = (
|
| 262 |
-
common._CUTLASS_KERNEL_RUN_GEMM_2x
|
| 263 |
-
+ """
|
| 264 |
-
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
|
| 265 |
-
int M = A.size(0);
|
| 266 |
-
int N = B.size(1);
|
| 267 |
-
int K = A.size(1);
|
| 268 |
-
|
| 269 |
-
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 270 |
-
nullptr :
|
| 271 |
-
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
|
| 272 |
-
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
|
| 273 |
-
|
| 274 |
-
cutlass::Status status = ${name}_kernel_run(M, N, K,
|
| 275 |
-
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
|
| 276 |
-
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
|
| 277 |
-
ptrC,
|
| 278 |
-
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
|
| 279 |
-
ElementCompute(alpha), ElementCompute(beta));
|
| 280 |
-
|
| 281 |
-
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 282 |
-
return D;
|
| 283 |
-
}
|
| 284 |
-
"""
|
| 285 |
-
)
|
| 286 |
-
|
| 287 |
-
_PYTORCH_GEMM_IMPL_TEMPLATE_3x = (
|
| 288 |
-
common._CUTLASS_KERNEL_RUN_GEMM_3x
|
| 289 |
-
+ """
|
| 290 |
-
bool hw_info_queried = false;
|
| 291 |
-
cutlass::KernelHardwareInfo hw_info;
|
| 292 |
-
|
| 293 |
-
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C, float alpha, float beta) {
|
| 294 |
-
int M = A.size(0);
|
| 295 |
-
int N = B.size(1);
|
| 296 |
-
int K = A.size(1);
|
| 297 |
-
int L = 1;
|
| 298 |
-
|
| 299 |
-
// Query hardware info if we haven't already
|
| 300 |
-
if (!hw_info_queried) {
|
| 301 |
-
hw_info.device_id = 0;
|
| 302 |
-
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
typename DeviceKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 306 |
-
nullptr :
|
| 307 |
-
reinterpret_cast<typename DeviceKernel::ElementC*>(C->contiguous().data_ptr());
|
| 308 |
-
at::Tensor D = B.new_empty({M, N}, ${torch_type_C});
|
| 309 |
-
|
| 310 |
-
cutlass::Status status = ${name}_kernel_run(M, N, K, L,
|
| 311 |
-
reinterpret_cast<typename DeviceKernel::ElementA*>(A.contiguous().data_ptr()),
|
| 312 |
-
reinterpret_cast<typename DeviceKernel::ElementB*>(B.contiguous().data_ptr()),
|
| 313 |
-
ptrC,
|
| 314 |
-
reinterpret_cast<typename DeviceKernel::ElementC*>(D.contiguous().data_ptr()),
|
| 315 |
-
ElementCompute(alpha), ElementCompute(beta),
|
| 316 |
-
hw_info);
|
| 317 |
-
|
| 318 |
-
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 319 |
-
return D;
|
| 320 |
-
}
|
| 321 |
-
"""
|
| 322 |
-
)
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE = (
|
| 326 |
-
common._CUTLASS_KERNEL_RUN_GROUPED_GEMM_2x
|
| 327 |
-
+ """
|
| 328 |
-
std::vector<at::Tensor> ${name}_kernel(const std::vector<at::Tensor>& A, const std::vector<at::Tensor>& B, at::optional<const std::vector<at::Tensor>> C, float alpha, float beta) {
|
| 329 |
-
size_t num = A.size();
|
| 330 |
-
|
| 331 |
-
// To avoid performing many small cudaMallocs and host-to-device copies,
|
| 332 |
-
// we serialize the grouped GEMM arguments on the host, allocate one
|
| 333 |
-
// large chunk of device memory, and perform a single cudaMemcpy to
|
| 334 |
-
// copy the host data to the device. Allocation overheads could be
|
| 335 |
-
// avoided by using a memory pool.
|
| 336 |
-
|
| 337 |
-
// Calculate the total size of the data to be copied from host to device
|
| 338 |
-
size_t total_size = sizeof(cutlass::gemm::GemmCoord) +
|
| 339 |
-
sizeof(DeviceKernel::ElementA*) +
|
| 340 |
-
sizeof(DeviceKernel::ElementB*) +
|
| 341 |
-
sizeof(DeviceKernel::ElementC*) +
|
| 342 |
-
sizeof(DeviceKernel::ElementC*) +
|
| 343 |
-
sizeof(int64_t) +
|
| 344 |
-
sizeof(int64_t) +
|
| 345 |
-
sizeof(int64_t);
|
| 346 |
-
total_size *= num;
|
| 347 |
-
|
| 348 |
-
// num * sizeof(cutlass::gemm::GemmCoord) may leave one at a non-multiple
|
| 349 |
-
// of sizeof(DeviceKernel::ElementA*) (which will be 64 on a 64-bit system).
|
| 350 |
-
// To ensure that we don't end up having misaligned loads in the kernel,
|
| 351 |
-
// we pad to the nearest multiple of 8.
|
| 352 |
-
//
|
| 353 |
-
// Note that, even on a 32-bit system (for which sizeof(X*) will not equal
|
| 354 |
-
// sizeof(int64_t)), only padding between the list of GemmCoords and the
|
| 355 |
-
// list of ptr_As is sufficient because the set of four equal-length lists of pointers
|
| 356 |
-
// (A*, B*, C*, D*) will ensure that the first list of int64_ts will always
|
| 357 |
-
// start on a multiple of 8.
|
| 358 |
-
int64_t padding = 8 - (total_size % 8);
|
| 359 |
-
total_size += padding;
|
| 360 |
-
|
| 361 |
-
uint8_t* host_data = new uint8_t[total_size];
|
| 362 |
-
cutlass::DeviceAllocation<uint8_t> device_data(total_size);
|
| 363 |
-
|
| 364 |
-
uint8_t* start = host_data;
|
| 365 |
-
cutlass::gemm::GemmCoord* problem_sizes_host = reinterpret_cast<cutlass::gemm::GemmCoord*>(start);
|
| 366 |
-
|
| 367 |
-
// Apply the padding after the list of GemmCoords
|
| 368 |
-
start += num * sizeof(cutlass::gemm::GemmCoord) + padding;
|
| 369 |
-
|
| 370 |
-
int64_t ptr_A_offset = start - host_data;
|
| 371 |
-
DeviceKernel::ElementA** ptr_A_host = reinterpret_cast<DeviceKernel::ElementA**>(start);
|
| 372 |
-
start += num * sizeof(DeviceKernel::ElementA*);
|
| 373 |
-
|
| 374 |
-
int64_t ptr_B_offset = start - host_data;
|
| 375 |
-
DeviceKernel::ElementB** ptr_B_host = reinterpret_cast<DeviceKernel::ElementB**>(start);
|
| 376 |
-
start += num * sizeof(DeviceKernel::ElementB*);
|
| 377 |
-
|
| 378 |
-
int64_t ptr_C_offset = start - host_data;
|
| 379 |
-
DeviceKernel::ElementC** ptr_C_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
|
| 380 |
-
start += num * sizeof(DeviceKernel::ElementC*);
|
| 381 |
-
|
| 382 |
-
int64_t ptr_D_offset = start - host_data;
|
| 383 |
-
DeviceKernel::ElementC** ptr_D_host = reinterpret_cast<DeviceKernel::ElementC**>(start);
|
| 384 |
-
start += num * sizeof(DeviceKernel::ElementC*);
|
| 385 |
-
|
| 386 |
-
int64_t lda_offset = start - host_data;
|
| 387 |
-
int64_t* lda_host = reinterpret_cast<int64_t*>(start);
|
| 388 |
-
start += num * sizeof(int64_t);
|
| 389 |
-
|
| 390 |
-
int64_t ldb_offset = start - host_data;
|
| 391 |
-
int64_t* ldb_host = reinterpret_cast<int64_t*>(start);
|
| 392 |
-
start += num * sizeof(int64_t);
|
| 393 |
-
|
| 394 |
-
int64_t ldc_offset = start - host_data;
|
| 395 |
-
int64_t* ldc_host = reinterpret_cast<int64_t*>(start);
|
| 396 |
-
start += num * sizeof(int64_t);
|
| 397 |
-
|
| 398 |
-
std::vector<at::Tensor> D(num);
|
| 399 |
-
|
| 400 |
-
bool need_C = (C != at::nullopt) && (beta != 0.f);
|
| 401 |
-
for (size_t i = 0; i < num; ++i) {
|
| 402 |
-
int M = A[i].size(0);
|
| 403 |
-
int N = B[i].size(1);
|
| 404 |
-
int K = A[i].size(1);
|
| 405 |
-
*(problem_sizes_host + i) = {M, N, K};
|
| 406 |
-
*(ptr_A_host + i) = reinterpret_cast<typename DeviceKernel::ElementA*>(A[i].contiguous().data_ptr());
|
| 407 |
-
*(ptr_B_host + i) = reinterpret_cast<typename DeviceKernel::ElementB*>(B[i].contiguous().data_ptr());
|
| 408 |
-
|
| 409 |
-
if (need_C) {
|
| 410 |
-
*(ptr_C_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(C->at(i).contiguous().data_ptr());
|
| 411 |
-
}
|
| 412 |
-
else {
|
| 413 |
-
*(ptr_C_host + i) = nullptr;
|
| 414 |
-
}
|
| 415 |
-
|
| 416 |
-
D[i] = B[i].new_empty({M, N}, ${torch_type_C});
|
| 417 |
-
*(ptr_D_host + i) = reinterpret_cast<typename DeviceKernel::ElementC*>(D[i].contiguous().data_ptr());
|
| 418 |
-
|
| 419 |
-
*(lda_host + i) = DeviceKernel::LayoutA::packed({M, K}).stride(0);
|
| 420 |
-
*(ldb_host + i) = DeviceKernel::LayoutB::packed({K, N}).stride(0);
|
| 421 |
-
*(ldc_host + i) = DeviceKernel::LayoutC::packed({M, N}).stride(0);
|
| 422 |
-
}
|
| 423 |
-
|
| 424 |
-
device_data.copy_from_host(host_data);
|
| 425 |
-
|
| 426 |
-
cutlass::Status status = ${name}_kernel_run(
|
| 427 |
-
num,
|
| 428 |
-
reinterpret_cast<cutlass::gemm::GemmCoord*>(device_data.get()),
|
| 429 |
-
reinterpret_cast<DeviceKernel::ElementA**>(device_data.get() + ptr_A_offset),
|
| 430 |
-
reinterpret_cast<DeviceKernel::ElementB**>(device_data.get() + ptr_B_offset),
|
| 431 |
-
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_C_offset),
|
| 432 |
-
reinterpret_cast<DeviceKernel::ElementC**>(device_data.get() + ptr_D_offset),
|
| 433 |
-
reinterpret_cast<int64_t*>(device_data.get() + lda_offset),
|
| 434 |
-
reinterpret_cast<int64_t*>(device_data.get() + ldb_offset),
|
| 435 |
-
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
|
| 436 |
-
reinterpret_cast<int64_t*>(device_data.get() + ldc_offset),
|
| 437 |
-
ElementCompute(alpha), ElementCompute(beta));
|
| 438 |
-
|
| 439 |
-
delete[] host_data;
|
| 440 |
-
|
| 441 |
-
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 442 |
-
return D;
|
| 443 |
-
}
|
| 444 |
-
"""
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """
|
| 448 |
-
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
| 449 |
-
|
| 450 |
-
cutlass::Status status = ${name}_kernel_run(
|
| 451 |
-
&problem_size,
|
| 452 |
-
reinterpret_cast<typename UnderlyingKernel::ElementA*>(A.data_ptr()),
|
| 453 |
-
reinterpret_cast<typename UnderlyingKernel::ElementB*>(B.data_ptr()),
|
| 454 |
-
ptrC,
|
| 455 |
-
reinterpret_cast<typename UnderlyingKernel::ElementC*>(D.data_ptr()),
|
| 456 |
-
alpha, beta,
|
| 457 |
-
split_k_mode, stream, B.device().index());
|
| 458 |
-
|
| 459 |
-
TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed");
|
| 460 |
-
return D;
|
| 461 |
-
}
|
| 462 |
-
"""
|
| 463 |
-
|
| 464 |
-
_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = (
|
| 465 |
-
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
| 466 |
-
+ """
|
| 467 |
-
at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 468 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1},
|
| 469 |
-
float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) {
|
| 470 |
-
int N, H, W, C_, K, R, S, P, Q;
|
| 471 |
-
N = A.size(0);
|
| 472 |
-
C_ = A.size(1);
|
| 473 |
-
H = A.size(2);
|
| 474 |
-
W = A.size(3);
|
| 475 |
-
|
| 476 |
-
K = B.size(0);
|
| 477 |
-
R = B.size(2);
|
| 478 |
-
S = B.size(3);
|
| 479 |
-
|
| 480 |
-
cutlass::conv::Conv2dProblemSize problem_size(
|
| 481 |
-
cutlass::Tensor4DCoord(N, H, W, C_),
|
| 482 |
-
cutlass::Tensor4DCoord(K, R, S, C_),
|
| 483 |
-
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
| 484 |
-
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
| 485 |
-
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
| 486 |
-
cutlass::conv::Mode::kCrossCorrelation,
|
| 487 |
-
split_k_slices
|
| 488 |
-
);
|
| 489 |
-
|
| 490 |
-
P = problem_size.P;
|
| 491 |
-
Q = problem_size.Q;
|
| 492 |
-
|
| 493 |
-
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 494 |
-
nullptr :
|
| 495 |
-
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
| 496 |
-
|
| 497 |
-
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
| 498 |
-
at::Tensor D = torch::zeros({N, K, P, Q}, options);
|
| 499 |
-
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
| 500 |
-
)
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = (
|
| 504 |
-
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
| 505 |
-
+ """
|
| 506 |
-
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> input_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 507 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
| 508 |
-
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 509 |
-
int N, H, W, C_, K, R, S;
|
| 510 |
-
N = std::get<0>(input_size);
|
| 511 |
-
C_ = std::get<1>(input_size);
|
| 512 |
-
H = std::get<2>(input_size);
|
| 513 |
-
W = std::get<3>(input_size);
|
| 514 |
-
|
| 515 |
-
K = B.size(0);
|
| 516 |
-
R = B.size(2);
|
| 517 |
-
S = B.size(3);
|
| 518 |
-
|
| 519 |
-
cutlass::conv::Conv2dProblemSize problem_size(
|
| 520 |
-
cutlass::Tensor4DCoord(N, H, W, C_),
|
| 521 |
-
cutlass::Tensor4DCoord(K, R, S, C_),
|
| 522 |
-
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
| 523 |
-
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
| 524 |
-
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
| 525 |
-
cutlass::conv::Mode::kCrossCorrelation,
|
| 526 |
-
split_k_slices
|
| 527 |
-
);
|
| 528 |
-
|
| 529 |
-
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 530 |
-
nullptr :
|
| 531 |
-
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
| 532 |
-
|
| 533 |
-
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
| 534 |
-
at::Tensor D = torch::empty({N, C_, H, W}, options);
|
| 535 |
-
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
| 536 |
-
)
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = (
|
| 540 |
-
common._CUTLASS_KERNEL_RUN_CONV2D_2x
|
| 541 |
-
+ """
|
| 542 |
-
at::Tensor ${name}_kernel(std::tuple<int, int, int, int> weight_size, const at::Tensor& A, const at::Tensor& B, at::optional<const at::Tensor> C=at::nullopt,
|
| 543 |
-
std::tuple<int, int> stride={1, 1}, std::tuple<int, int> padding={0, 0}, std::tuple<int, int> dilation={1, 1}, float alpha=1.f, float beta=0.f,
|
| 544 |
-
std::string split_k_mode="serial", int split_k_slices=1) {
|
| 545 |
-
int N, H, W, C_, K, R, S;
|
| 546 |
-
K = std::get<0>(weight_size);
|
| 547 |
-
C_ = std::get<1>(weight_size);
|
| 548 |
-
R = std::get<2>(weight_size);
|
| 549 |
-
S = std::get<3>(weight_size);
|
| 550 |
-
|
| 551 |
-
N = B.size(0);
|
| 552 |
-
H = B.size(2);
|
| 553 |
-
W = B.size(3);
|
| 554 |
-
|
| 555 |
-
cutlass::conv::Conv2dProblemSize problem_size(
|
| 556 |
-
cutlass::Tensor4DCoord(N, H, W, C_),
|
| 557 |
-
cutlass::Tensor4DCoord(K, R, S, C_),
|
| 558 |
-
cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)),
|
| 559 |
-
cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)),
|
| 560 |
-
cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)),
|
| 561 |
-
cutlass::conv::Mode::kCrossCorrelation,
|
| 562 |
-
split_k_slices
|
| 563 |
-
);
|
| 564 |
-
|
| 565 |
-
typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ?
|
| 566 |
-
nullptr :
|
| 567 |
-
reinterpret_cast<typename UnderlyingKernel::ElementC*>(C->data_ptr());
|
| 568 |
-
|
| 569 |
-
torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast);
|
| 570 |
-
at::Tensor D = torch::empty({K, C_, R, S}, options);
|
| 571 |
-
""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x
|
| 572 |
-
)
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
_PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """
|
| 576 |
-
from setuptools import setup
|
| 577 |
-
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
| 578 |
-
|
| 579 |
-
setup(
|
| 580 |
-
name='${name}',
|
| 581 |
-
ext_modules=[
|
| 582 |
-
CUDAExtension('${name}', [
|
| 583 |
-
'${name}.cpp',
|
| 584 |
-
'${name}_kernel.cu',
|
| 585 |
-
],
|
| 586 |
-
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
|
| 587 |
-
extra_compile_args={
|
| 588 |
-
'cxx': ['-std=c++17'],
|
| 589 |
-
'nvcc': ['-std=c++17', ${extra_compile_args}],
|
| 590 |
-
},
|
| 591 |
-
libraries=['cuda']
|
| 592 |
-
),
|
| 593 |
-
],
|
| 594 |
-
cmdclass={
|
| 595 |
-
'build_ext': BuildExtension
|
| 596 |
-
})
|
| 597 |
-
|
| 598 |
-
"""
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
|
| 602 |
-
"""
|
| 603 |
-
Generates a setup.py file for the extension
|
| 604 |
-
|
| 605 |
-
:param name: name of the module to generate
|
| 606 |
-
:type name: str
|
| 607 |
-
:param sourcedir: directory to which generated source files should be written
|
| 608 |
-
:type sourcedir: str
|
| 609 |
-
:param extra_compile_args: additional arguments to pass to setup.py
|
| 610 |
-
:type extra_args: str
|
| 611 |
-
"""
|
| 612 |
-
setup_py_file = os.path.join(sourcedir, "setup.py")
|
| 613 |
-
setup_source = SubstituteTemplate(
|
| 614 |
-
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
|
| 615 |
-
)
|
| 616 |
-
with open(setup_py_file, "w") as outfile:
|
| 617 |
-
outfile.write(setup_source)
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
class _ArchListSetter:
|
| 621 |
-
"""
|
| 622 |
-
Utility context manager for temporarily setting the value of the ``TORCH_CUDA_ARCH_LIST``
|
| 623 |
-
environment variable when building a PyTorch CUDA module.
|
| 624 |
-
|
| 625 |
-
``TORCH_CUDA_ARCH_LIST`` is a space-delmited list of compute capabilites for which a PyTorch
|
| 626 |
-
CUDA module should be compiled.
|
| 627 |
-
|
| 628 |
-
For example, ``TORCH_CUDA_ARCH_LIST="7.0 8.0"`` would result in the inclusion of
|
| 629 |
-
``-gencode=arch=compute_70,code=sm_70`` and ``-gencode=arch=compute_80,code=sm_80`` in the
|
| 630 |
-
compilation of the module.
|
| 631 |
-
|
| 632 |
-
This utility wraps the building of a PyTorch CUDA module with a setting of this environment
|
| 633 |
-
variable according to the current compute capability being targetted.
|
| 634 |
-
|
| 635 |
-
Example usage:
|
| 636 |
-
|
| 637 |
-
.. highlight:: python
|
| 638 |
-
.. code-block:: python
|
| 639 |
-
|
| 640 |
-
# Temporarily set TORCH_CUDA_ARCH_LIST="8.0"
|
| 641 |
-
with _ArchListSetter(80):
|
| 642 |
-
# Perform JIT compilation and loading of the module
|
| 643 |
-
mod = torch.utils.cpp_extension.load(...)
|
| 644 |
-
|
| 645 |
-
:param cc: compute capability
|
| 646 |
-
:type cc: int
|
| 647 |
-
"""
|
| 648 |
-
|
| 649 |
-
_TORCH_CUDA_ARCH_LIST = "TORCH_CUDA_ARCH_LIST"
|
| 650 |
-
|
| 651 |
-
def __init__(self, cc: int):
|
| 652 |
-
self.cc_str = ".".join(list(str(cc)))
|
| 653 |
-
|
| 654 |
-
def __enter__(self):
|
| 655 |
-
"""
|
| 656 |
-
Saves the old value of TORCH_CUDA_ARCH_LIST and reset it to the new value based on ``cc``
|
| 657 |
-
"""
|
| 658 |
-
self.old_arch_list = os.getenv(_ArchListSetter._TORCH_CUDA_ARCH_LIST)
|
| 659 |
-
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.cc_str
|
| 660 |
-
|
| 661 |
-
return self
|
| 662 |
-
|
| 663 |
-
def __exit__(self, exc_type, exc_val, traceback):
|
| 664 |
-
"""
|
| 665 |
-
Restores the old value of TORCH_CUDA_ARCH_LIST
|
| 666 |
-
"""
|
| 667 |
-
if self.old_arch_list is None:
|
| 668 |
-
del os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST]
|
| 669 |
-
else:
|
| 670 |
-
os.environ[_ArchListSetter._TORCH_CUDA_ARCH_LIST] = self.old_arch_list
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
|
| 674 |
-
"""
|
| 675 |
-
JIT compiles and loads a PyTorch CUDA extension.
|
| 676 |
-
|
| 677 |
-
:param name: name of the module to generate
|
| 678 |
-
:type name: str
|
| 679 |
-
:param cc: compute capability of the device the module should target
|
| 680 |
-
:type cc: int
|
| 681 |
-
:param cpp_file: path to file containing extension's C++ interface
|
| 682 |
-
:type cpp_file: str
|
| 683 |
-
:param cuda_file: path to file containing extension's CUDA interface
|
| 684 |
-
:type cuda_file: str
|
| 685 |
-
|
| 686 |
-
:return: loaded PyTorch module
|
| 687 |
-
"""
|
| 688 |
-
|
| 689 |
-
from torch.utils.cpp_extension import load
|
| 690 |
-
|
| 691 |
-
extra_cuda_cflags = ["-std=c++17"]
|
| 692 |
-
if cc in [90, 100, 101, 103]:
|
| 693 |
-
# PyTorch does not currently add the sm_90a target when compute capability
|
| 694 |
-
# 9.0 is set within TORCH_CUDA_ARCH_LIST. Thus, we manually add the sm_90a target.
|
| 695 |
-
extra_cuda_cflags.append(f"-gencode=arch=compute_{cc}a,code=sm_{cc}a")
|
| 696 |
-
|
| 697 |
-
with _ArchListSetter(cc):
|
| 698 |
-
jitmodule = load(
|
| 699 |
-
name,
|
| 700 |
-
[cpp_file, cuda_file],
|
| 701 |
-
extra_cuda_cflags=extra_cuda_cflags,
|
| 702 |
-
extra_include_paths=[
|
| 703 |
-
os.path.join(CUTLASS_PATH, "include"),
|
| 704 |
-
os.path.join(CUTLASS_PATH, "tools/util/include"),
|
| 705 |
-
],
|
| 706 |
-
extra_ldflags=["-lcuda"],
|
| 707 |
-
verbose=(logger.level == logging.DEBUG)
|
| 708 |
-
)
|
| 709 |
-
return jitmodule
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
| 713 |
-
"""
|
| 714 |
-
Generates source for building a PyTorch CUDA module that leverages the CUTLASS GEMM
|
| 715 |
-
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 716 |
-
compiled, loaded, and returned.
|
| 717 |
-
|
| 718 |
-
:param op: operation to emit in the module
|
| 719 |
-
:param name: name of the module to generate
|
| 720 |
-
:type name: str
|
| 721 |
-
:param cc: compute capability of the device the module should target
|
| 722 |
-
:type cc: int
|
| 723 |
-
:param jit: whether the module should be just-in-time compiled
|
| 724 |
-
:type jit: bool
|
| 725 |
-
:param sourcedir: directory to which generated source files should be written
|
| 726 |
-
:type sourcedir: str
|
| 727 |
-
|
| 728 |
-
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
| 729 |
-
"""
|
| 730 |
-
if sourcedir != "" and not os.path.isdir(sourcedir):
|
| 731 |
-
os.makedirs(sourcedir)
|
| 732 |
-
|
| 733 |
-
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
| 734 |
-
extra_kw = {}
|
| 735 |
-
if op.api == ApiVersion.v3x:
|
| 736 |
-
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_3x
|
| 737 |
-
else:
|
| 738 |
-
impl_template = _PYTORCH_GEMM_IMPL_TEMPLATE_2x
|
| 739 |
-
if op.swizzling_functor == swizzle.ThreadblockSwizzleStreamK:
|
| 740 |
-
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x_STREAM_K
|
| 741 |
-
else:
|
| 742 |
-
extra_kw["args"] = common._CUTLASS_KERNEL_ARGS_2x
|
| 743 |
-
impl_template = (
|
| 744 |
-
_PYTORCH_GEMM_IMPL_TEMPLATE_3x
|
| 745 |
-
if op.api == ApiVersion.v3x
|
| 746 |
-
else _PYTORCH_GEMM_IMPL_TEMPLATE_2x
|
| 747 |
-
)
|
| 748 |
-
cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
|
| 749 |
-
cuda_source = SubstituteTemplate(
|
| 750 |
-
_PYTORCH_CUDA_TEMPLATE,
|
| 751 |
-
{
|
| 752 |
-
"includes": _PYTORCH_GEMM_INCLUDES[op.api],
|
| 753 |
-
"declaration": op.rt_module.emit(),
|
| 754 |
-
"procedural_name": op.procedural_name(),
|
| 755 |
-
"impl": cuda_impl,
|
| 756 |
-
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
| 757 |
-
},
|
| 758 |
-
)
|
| 759 |
-
with open(cuda_file, "w") as outfile:
|
| 760 |
-
outfile.write(cuda_source)
|
| 761 |
-
|
| 762 |
-
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
| 763 |
-
cpp_source = SubstituteTemplate(
|
| 764 |
-
_PYTORCH_GEMM_CPP_TEMPLATE,
|
| 765 |
-
{"name": name, "description": f"CUTLASS {op.procedural_name()} GEMM"},
|
| 766 |
-
)
|
| 767 |
-
with open(cpp_file, "w") as outfile:
|
| 768 |
-
outfile.write(cpp_source)
|
| 769 |
-
|
| 770 |
-
extra_compile_args = ""
|
| 771 |
-
if cc in [90, 100, 101, 103]:
|
| 772 |
-
extra_compile_args = f"'--generate-code=arch=compute_{cc}a,code=[sm_{cc}a]'"
|
| 773 |
-
_generate_setup(name, sourcedir, extra_compile_args)
|
| 774 |
-
|
| 775 |
-
if jit:
|
| 776 |
-
return _jit(name, cc, cpp_file, cuda_file)
|
| 777 |
-
|
| 778 |
-
return None
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
def _pytorch_grouped_gemm(
|
| 782 |
-
op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
|
| 783 |
-
):
|
| 784 |
-
"""
|
| 785 |
-
Generates source for building a PyTorch CUDA module that leverages the CUTLASS grouped GEMM
|
| 786 |
-
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 787 |
-
compiled, loaded, and returned.
|
| 788 |
-
|
| 789 |
-
:param op: operation to emit in the module
|
| 790 |
-
:param name: name of the module to generate
|
| 791 |
-
:type name: str
|
| 792 |
-
:param cc: compute capability of the device the module should target
|
| 793 |
-
:type cc: int
|
| 794 |
-
:param jit: whether the module should be just-in-time compiled
|
| 795 |
-
:type jit: bool
|
| 796 |
-
:param sourcedir: directory to which generated source files should be written
|
| 797 |
-
:type sourcedir: str
|
| 798 |
-
|
| 799 |
-
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
| 800 |
-
"""
|
| 801 |
-
if op.api != ApiVersion.v2x:
|
| 802 |
-
raise Exception("Grouped GEMM is currently only supported for CUTLASS 2.x")
|
| 803 |
-
|
| 804 |
-
if sourcedir != "" and not os.path.isdir(sourcedir):
|
| 805 |
-
os.makedirs(sourcedir)
|
| 806 |
-
|
| 807 |
-
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
| 808 |
-
cuda_impl = SubstituteTemplate(_PYTORCH_GROUPED_GEMM_IMPL_TEMPLATE, {"name": name})
|
| 809 |
-
cuda_source = SubstituteTemplate(
|
| 810 |
-
_PYTORCH_CUDA_TEMPLATE,
|
| 811 |
-
{
|
| 812 |
-
"includes": _PYTORCH_GROUPED_GEMM_INCLUDES,
|
| 813 |
-
"declaration": op.rt_module.emit(),
|
| 814 |
-
"procedural_name": op.procedural_name(),
|
| 815 |
-
"impl": cuda_impl,
|
| 816 |
-
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
| 817 |
-
},
|
| 818 |
-
)
|
| 819 |
-
with open(cuda_file, "w") as outfile:
|
| 820 |
-
outfile.write(cuda_source)
|
| 821 |
-
|
| 822 |
-
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
| 823 |
-
cpp_source = SubstituteTemplate(
|
| 824 |
-
_PYTORCH_GROUPED_GEMM_CPP_TEMPLATE,
|
| 825 |
-
{"name": name, "description": f"CUTLASS {op.procedural_name()} grouped GEMM"},
|
| 826 |
-
)
|
| 827 |
-
with open(cpp_file, "w") as outfile:
|
| 828 |
-
outfile.write(cpp_source)
|
| 829 |
-
|
| 830 |
-
_generate_setup(name, sourcedir)
|
| 831 |
-
|
| 832 |
-
if jit:
|
| 833 |
-
return _jit(name, cc, cpp_file, cuda_file)
|
| 834 |
-
|
| 835 |
-
return None
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
| 839 |
-
"""
|
| 840 |
-
Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d
|
| 841 |
-
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 842 |
-
compiled, loaded, and returned.
|
| 843 |
-
|
| 844 |
-
:param op: operation to emit in the module
|
| 845 |
-
:param name: name of the module to generate
|
| 846 |
-
:type name: str
|
| 847 |
-
:param cc: compute capability of the device the module should target
|
| 848 |
-
:type cc: int
|
| 849 |
-
:param jit: whether the module should be just-in-time compiled
|
| 850 |
-
:type jit: bool
|
| 851 |
-
:param sourcedir: directory to which generated source files should be written
|
| 852 |
-
:type sourcedir: str
|
| 853 |
-
|
| 854 |
-
Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or
|
| 855 |
-
weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions
|
| 856 |
-
for H/W/R/S given the same P/Q.
|
| 857 |
-
|
| 858 |
-
:return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise
|
| 859 |
-
"""
|
| 860 |
-
if sourcedir != "" and not os.path.isdir(sourcedir):
|
| 861 |
-
os.makedirs(sourcedir)
|
| 862 |
-
cuda_file = os.path.join(sourcedir, name + "_kernel.cu")
|
| 863 |
-
extra_kw = {}
|
| 864 |
-
if op.conv_kind == ConvKind.Fprop:
|
| 865 |
-
impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x
|
| 866 |
-
cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE
|
| 867 |
-
elif op.conv_kind == ConvKind.Dgrad:
|
| 868 |
-
impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x
|
| 869 |
-
cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
|
| 870 |
-
elif op.conv_kind == ConvKind.Wgrad:
|
| 871 |
-
impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x
|
| 872 |
-
cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE
|
| 873 |
-
extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize()
|
| 874 |
-
extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element]
|
| 875 |
-
cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw})
|
| 876 |
-
cuda_source = SubstituteTemplate(
|
| 877 |
-
_PYTORCH_CUDA_TEMPLATE,
|
| 878 |
-
{
|
| 879 |
-
"includes": _PYTORCH_CONV2D_INCLUDES,
|
| 880 |
-
"declaration": op.rt_module.emit(),
|
| 881 |
-
"procedural_name": op.procedural_name(),
|
| 882 |
-
"impl": cuda_impl,
|
| 883 |
-
"torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element],
|
| 884 |
-
},
|
| 885 |
-
)
|
| 886 |
-
with open(cuda_file, "w") as outfile:
|
| 887 |
-
outfile.write(cuda_source)
|
| 888 |
-
|
| 889 |
-
cpp_file = os.path.join(sourcedir, name + ".cpp")
|
| 890 |
-
cpp_source = SubstituteTemplate(
|
| 891 |
-
cpp_template,
|
| 892 |
-
{"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"},
|
| 893 |
-
)
|
| 894 |
-
with open(cpp_file, "w") as outfile:
|
| 895 |
-
outfile.write(cpp_source)
|
| 896 |
-
|
| 897 |
-
_generate_setup(name, sourcedir)
|
| 898 |
-
|
| 899 |
-
if jit:
|
| 900 |
-
return _jit(name, cc, cpp_file, cuda_file)
|
| 901 |
-
|
| 902 |
-
return None
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""):
|
| 906 |
-
"""
|
| 907 |
-
Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel
|
| 908 |
-
specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time
|
| 909 |
-
compiled, loaded, and returned.
|
| 910 |
-
|
| 911 |
-
The result of this method is files within ``sourcedir`` that can be used for building
|
| 912 |
-
a PyTorch module.
|
| 913 |
-
|
| 914 |
-
:param op: operation to emit in the module
|
| 915 |
-
:param name: name of the module to generate
|
| 916 |
-
:type name: str
|
| 917 |
-
:param cc: compute capability of the device the module should target
|
| 918 |
-
:type cc: int
|
| 919 |
-
:param jit: whether the module should be just-in-time compiled
|
| 920 |
-
:type jit: bool
|
| 921 |
-
:param sourcedir: directory to which generated source files should be written
|
| 922 |
-
:type sourcedir: str
|
| 923 |
-
|
| 924 |
-
:return: loaded PyTorch module (if ``jit=True``) or None
|
| 925 |
-
"""
|
| 926 |
-
device_op = op.device_op()
|
| 927 |
-
if isinstance(op, GemmOperationUniversal):
|
| 928 |
-
return _pytorch_gemm(device_op, name, cc, jit, sourcedir)
|
| 929 |
-
elif isinstance(op, GemmOperationGrouped):
|
| 930 |
-
return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir)
|
| 931 |
-
elif isinstance(op, Conv2dOperation):
|
| 932 |
-
return _pytorch_conv2d(device_op, name, cc, jit, sourcedir)
|
| 933 |
-
else:
|
| 934 |
-
raise Exception(
|
| 935 |
-
f"Operation type {type(op)} is not currently supported for PyTorch emission."
|
| 936 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/__init__.py
DELETED
|
@@ -1,56 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from cutlass_cppgen.epilogue.epilogue import (
|
| 34 |
-
get_activations,
|
| 35 |
-
get_activation_epilogue,
|
| 36 |
-
gelu,
|
| 37 |
-
hardswish,
|
| 38 |
-
identity,
|
| 39 |
-
leaky_relu,
|
| 40 |
-
relu,
|
| 41 |
-
sigmoid,
|
| 42 |
-
silu,
|
| 43 |
-
tanh,
|
| 44 |
-
trace
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
from cutlass_cppgen.epilogue.evt_ops import (
|
| 48 |
-
max,
|
| 49 |
-
multiply_add,
|
| 50 |
-
sum,
|
| 51 |
-
permute,
|
| 52 |
-
reshape,
|
| 53 |
-
maximum,
|
| 54 |
-
minimum,
|
| 55 |
-
exp
|
| 56 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/epilogue.py
DELETED
|
@@ -1,176 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Registry of elementwise epilogues
|
| 35 |
-
|
| 36 |
-
Elementwise epilogues can be added to many CUTLASS kernels in the CUTLAS Python interface via
|
| 37 |
-
code like the following for GEMM:
|
| 38 |
-
|
| 39 |
-
.. highlight:: python
|
| 40 |
-
.. code-block:: python
|
| 41 |
-
|
| 42 |
-
plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 43 |
-
plan.activation = cutlass_cppgen.epilogue.relu
|
| 44 |
-
"""
|
| 45 |
-
|
| 46 |
-
from cutlass_cppgen.backend import epilogue, device_cc
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
gelu = epilogue.gelu
|
| 50 |
-
hardswish = epilogue.hardswish
|
| 51 |
-
identity = epilogue.identity
|
| 52 |
-
leaky_relu = epilogue.leaky_relu
|
| 53 |
-
relu = epilogue.relu
|
| 54 |
-
sigmoid = epilogue.sigmoid
|
| 55 |
-
silu = epilogue.silu
|
| 56 |
-
tanh = epilogue.tanh
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
_activations = [gelu, hardswish, identity, leaky_relu, relu, sigmoid, silu, tanh]
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def get_activations() -> list:
|
| 63 |
-
"""
|
| 64 |
-
Returns a list of available activation functions
|
| 65 |
-
|
| 66 |
-
:return: list of available activation functions
|
| 67 |
-
:rtype: list
|
| 68 |
-
"""
|
| 69 |
-
return _activations
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def get_activation_epilogue(
|
| 73 |
-
activation,
|
| 74 |
-
element_output,
|
| 75 |
-
elements_per_access,
|
| 76 |
-
element_accumulator,
|
| 77 |
-
element_compute,
|
| 78 |
-
):
|
| 79 |
-
"""
|
| 80 |
-
Return an epilogue corresponding to the activation function, data types, and alignment
|
| 81 |
-
used in the kernel
|
| 82 |
-
|
| 83 |
-
:param activation: elementwise activation function to use
|
| 84 |
-
:param element_output: data type of the output
|
| 85 |
-
:param elements_per_access: alignment of operand C of the kernel
|
| 86 |
-
:type elements_per_access: int
|
| 87 |
-
:param element_accumulator: data type of the accumulated output C
|
| 88 |
-
:param element_compute: data type in which compute operations should be performed
|
| 89 |
-
|
| 90 |
-
:return: epilogue functor
|
| 91 |
-
"""
|
| 92 |
-
if activation not in _activations:
|
| 93 |
-
raise Exception(
|
| 94 |
-
f"Unsupported activation type {activation}. Available activations are: {_activations}"
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
-
if activation == identity:
|
| 98 |
-
return epilogue.LinearCombination(
|
| 99 |
-
element_output, elements_per_access, element_accumulator, element_compute
|
| 100 |
-
)
|
| 101 |
-
else:
|
| 102 |
-
return epilogue.LinearCombinationGeneric(
|
| 103 |
-
activation,
|
| 104 |
-
element_output,
|
| 105 |
-
elements_per_access,
|
| 106 |
-
element_accumulator,
|
| 107 |
-
element_compute,
|
| 108 |
-
)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
"""
|
| 112 |
-
Frontend for EVT that generates epilogue functor through tracing the input function
|
| 113 |
-
"""
|
| 114 |
-
from cutlass_cppgen.backend.evt.frontend import PythonASTFrontend
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
def trace(fn, example_tensors, **kwargs):
|
| 118 |
-
"""
|
| 119 |
-
Trace `fn(**example_tensors)` and generates epilogue visitor
|
| 120 |
-
|
| 121 |
-
:param fn or str: Python callable or string of the epilogue function
|
| 122 |
-
:param example_tensors: example inputs for fn
|
| 123 |
-
:type example_tensors: dict
|
| 124 |
-
|
| 125 |
-
.. hightlight:: python
|
| 126 |
-
.. code-block:: python
|
| 127 |
-
import cutlass_cppgen.backend.evt
|
| 128 |
-
|
| 129 |
-
# Define epilogue function as Python callable
|
| 130 |
-
def example_fn(accum, C, alpha, beta, gamma):
|
| 131 |
-
D = ((accum + C) * alpha - gamma) / beta
|
| 132 |
-
return D
|
| 133 |
-
|
| 134 |
-
# Define the example tensors
|
| 135 |
-
example_inputs = {
|
| 136 |
-
"accum": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
|
| 137 |
-
"C": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda"),
|
| 138 |
-
"alpha": 1.5,
|
| 139 |
-
"beta": 0.5,
|
| 140 |
-
"gamma": 2.5,
|
| 141 |
-
"D": torch.empty(size=(6, 512, 512), dtype=torch.float16, device="cuda")
|
| 142 |
-
}
|
| 143 |
-
|
| 144 |
-
# Generate the epilogue functor
|
| 145 |
-
epilogue_visitor = cutlass_cppgen.epilogue.trace(example_fn, example_inputs)
|
| 146 |
-
"""
|
| 147 |
-
if callable(fn):
|
| 148 |
-
class EpilogueFunctor(PythonASTFrontend):
|
| 149 |
-
def __init__(self, cc=None, **kwargs):
|
| 150 |
-
if not cc:
|
| 151 |
-
cc = device_cc()
|
| 152 |
-
super().__init__(cc, **kwargs)
|
| 153 |
-
pass
|
| 154 |
-
setattr(EpilogueFunctor, "__call__", staticmethod(fn))
|
| 155 |
-
|
| 156 |
-
epilogue_functor = EpilogueFunctor(**kwargs)
|
| 157 |
-
epilogue_functor.trace(example_tensors)
|
| 158 |
-
return epilogue_functor
|
| 159 |
-
elif isinstance(fn, str):
|
| 160 |
-
class EpilogueFunctor(PythonASTFrontend):
|
| 161 |
-
def __init__(self, cc=None, **kwargs):
|
| 162 |
-
self.source = textwrap.dedent(fn)
|
| 163 |
-
if not cc:
|
| 164 |
-
cc = device_cc()
|
| 165 |
-
super().__init__(cc, **kwargs)
|
| 166 |
-
|
| 167 |
-
def parse(self, example_inputs) -> None:
|
| 168 |
-
self.example_inputs = example_inputs
|
| 169 |
-
self.ast = ast.parse(self.source)
|
| 170 |
-
self.visit(self.ast)
|
| 171 |
-
|
| 172 |
-
epilogue_functor = EpilogueFunctor(**kwargs)
|
| 173 |
-
epilogue_functor.trace(example_tensors)
|
| 174 |
-
return epilogue_functor
|
| 175 |
-
else:
|
| 176 |
-
raise NotImplementedError("Expect a callable Python function")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/epilogue/evt_ops.py
DELETED
|
@@ -1,98 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Collection of builtin functions used for host reference in EVT
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import numpy as np
|
| 38 |
-
|
| 39 |
-
from cutlass_cppgen.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor
|
| 40 |
-
|
| 41 |
-
if is_torch_available():
|
| 42 |
-
import torch
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def multiply_add(x, y, z):
|
| 46 |
-
return x * y + z
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def sum(x, dim):
|
| 50 |
-
if is_numpy_tensor(x):
|
| 51 |
-
return x.sum(axis=tuple(dim))
|
| 52 |
-
elif is_torch_tensor(x):
|
| 53 |
-
return torch.sum(x, dim)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def max(x, dim):
|
| 57 |
-
if is_numpy_tensor(x):
|
| 58 |
-
return x.max(axis=tuple(dim))
|
| 59 |
-
elif is_torch_tensor(x):
|
| 60 |
-
return torch.amax(x, dim)
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def maximum(x, y):
|
| 64 |
-
if is_numpy_tensor(x):
|
| 65 |
-
return np.maximum(x, y)
|
| 66 |
-
elif is_torch_tensor(x):
|
| 67 |
-
return torch.maximum(x, torch.tensor(y))
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def minimum(x, y):
|
| 71 |
-
if is_numpy_tensor(x):
|
| 72 |
-
return np.minimum(x, y)
|
| 73 |
-
elif is_torch_tensor(x):
|
| 74 |
-
return torch.minimum(x, torch.tensor(y))
|
| 75 |
-
|
| 76 |
-
def exp(x):
|
| 77 |
-
if is_numpy_tensor(x):
|
| 78 |
-
return np.exp(x)
|
| 79 |
-
elif is_torch_tensor(x):
|
| 80 |
-
return torch.exp(x)
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
##############################################################################
|
| 84 |
-
# Layout manipulate nodes
|
| 85 |
-
##############################################################################
|
| 86 |
-
|
| 87 |
-
def permute(x, indices: tuple):
|
| 88 |
-
if is_numpy_tensor(x):
|
| 89 |
-
return np.transpose(x, axes=indices)
|
| 90 |
-
elif is_torch_tensor(x):
|
| 91 |
-
return x.permute(*indices)
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
def reshape(x, new_shape: tuple):
|
| 95 |
-
if is_numpy_tensor(x):
|
| 96 |
-
return np.reshape(x, newshape=new_shape)
|
| 97 |
-
elif is_torch_tensor(x):
|
| 98 |
-
return x.view(new_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/library_defaults.py
DELETED
|
@@ -1,569 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Classes containing valid operations for a given compute capability and data types.
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from itertools import combinations_with_replacement
|
| 38 |
-
import logging
|
| 39 |
-
|
| 40 |
-
import cutlass_library
|
| 41 |
-
from cutlass_library.library import ConvKind, IteratorAlgorithm, StrideSupport, GroupMode
|
| 42 |
-
|
| 43 |
-
import cutlass_cppgen
|
| 44 |
-
from cutlass_cppgen.utils.check import valid_stage_count
|
| 45 |
-
from cutlass_cppgen.utils.datatypes import td_from_profiler_td, td_from_profiler_op
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
_generator_ccs = [50, 60, 61, 70, 75, 80, 90, 100]
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
class KernelsForDataType:
|
| 52 |
-
"""
|
| 53 |
-
Container class for keeping track of kernels that correspond to a particular combination
|
| 54 |
-
of data types for operands A, B, and accumulator
|
| 55 |
-
"""
|
| 56 |
-
|
| 57 |
-
def __init__(self, datatype_comb: tuple, layout_comb: tuple):
|
| 58 |
-
self.datatype_comb = datatype_comb
|
| 59 |
-
self.layout_comb = layout_comb
|
| 60 |
-
self.math_operations = set()
|
| 61 |
-
|
| 62 |
-
# Dictionary mapping from alignment (int) to a list of kernels that fit the alignment
|
| 63 |
-
# constraint for the data type combination
|
| 64 |
-
self.kernels_by_alignment = {}
|
| 65 |
-
|
| 66 |
-
def add(self, operation):
|
| 67 |
-
"""
|
| 68 |
-
Add an operation to the list of supported kernels
|
| 69 |
-
"""
|
| 70 |
-
alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}"
|
| 71 |
-
if alignment_key not in self.kernels_by_alignment:
|
| 72 |
-
self.kernels_by_alignment[alignment_key] = []
|
| 73 |
-
self.kernels_by_alignment[alignment_key].append(operation)
|
| 74 |
-
self.math_operations.add(operation.tile_description.math_instruction.math_operation)
|
| 75 |
-
|
| 76 |
-
def alignments(self, operand: str):
|
| 77 |
-
"""
|
| 78 |
-
Returns an unsorted list of alignments supported by this data type combination
|
| 79 |
-
|
| 80 |
-
:param operand: identifier of operand in question (e.g., A, B, C)
|
| 81 |
-
:type operand: str
|
| 82 |
-
|
| 83 |
-
:return: unsorted list of alignments supported by this data type combination
|
| 84 |
-
:rtype: list
|
| 85 |
-
"""
|
| 86 |
-
operand_idx = self._operand_idx(operand)
|
| 87 |
-
return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()]
|
| 88 |
-
|
| 89 |
-
@property
|
| 90 |
-
def all_operations(self):
|
| 91 |
-
"""
|
| 92 |
-
Returns a list of all operations supported by this data type combination
|
| 93 |
-
|
| 94 |
-
:return: list of all operations supported by this data type combination
|
| 95 |
-
:rtype: list
|
| 96 |
-
"""
|
| 97 |
-
ops = []
|
| 98 |
-
for _, alignment_ops in self.kernels_by_alignment.items():
|
| 99 |
-
ops.extend(alignment_ops)
|
| 100 |
-
return ops
|
| 101 |
-
|
| 102 |
-
def default_operation(self, math_operation: cutlass_cppgen.MathOperation):
|
| 103 |
-
key = sorted(list(self.kernels_by_alignment.keys()))[0]
|
| 104 |
-
kernels = self.kernels_by_alignment[key]
|
| 105 |
-
if math_operation is not None:
|
| 106 |
-
kernels = [x for x in kernels if x.tile_description.math_instruction.math_operation == math_operation]
|
| 107 |
-
return kernels[0]
|
| 108 |
-
|
| 109 |
-
def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_operation: cutlass_cppgen.MathOperation):
|
| 110 |
-
"""
|
| 111 |
-
Returns operations satisfying the alignment constraints
|
| 112 |
-
|
| 113 |
-
:param alignment_A: alignment constraint of operations to return
|
| 114 |
-
:type alignment_A: int
|
| 115 |
-
:param alignment_B: alignment constraint of operations to return
|
| 116 |
-
:type alignment_B: int
|
| 117 |
-
:param alignment_C: alignment constraint of operations to return
|
| 118 |
-
:type alignment_C: int
|
| 119 |
-
:param math_operation: math operation to consider
|
| 120 |
-
:type math_operation: cutlass_cppgen.MathOperation
|
| 121 |
-
|
| 122 |
-
:return: list of operations
|
| 123 |
-
:rtype: list
|
| 124 |
-
"""
|
| 125 |
-
key = f"{alignment_A} {alignment_B} {alignment_C}"
|
| 126 |
-
|
| 127 |
-
if key not in self.kernels_by_alignment:
|
| 128 |
-
og_key = key
|
| 129 |
-
# Reconcile A, B, and C alignments by trying to align to the minimum
|
| 130 |
-
min_alignment = min(alignment_A, alignment_B, alignment_C)
|
| 131 |
-
key = f"{min_alignment} {min_alignment} {min_alignment}"
|
| 132 |
-
if key not in self.kernels_by_alignment:
|
| 133 |
-
# Finally, go through all available alignment combinations and find
|
| 134 |
-
# one for which all values are less than those passed in.
|
| 135 |
-
key = None
|
| 136 |
-
alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
| 137 |
-
for align_A, align_B, align_C in alignments:
|
| 138 |
-
if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
|
| 139 |
-
key = f"{align_A} {align_B} {align_C}"
|
| 140 |
-
break
|
| 141 |
-
|
| 142 |
-
if key is None:
|
| 143 |
-
raise Exception(
|
| 144 |
-
f"No operations of alignment {og_key} found for data type and layout "
|
| 145 |
-
f"combination {self.datatype_comb} {self.layout_comb}. Compatible alignments "
|
| 146 |
-
f"are {self.kernels_by_alignment.keys()}"
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
ops = self.kernels_by_alignment[key]
|
| 150 |
-
if math_operation is not None:
|
| 151 |
-
ops = [op for op in ops if op.tile_description.math_instruction.math_operation == math_operation]
|
| 152 |
-
return ops
|
| 153 |
-
|
| 154 |
-
def _operand_idx(self, key: str) -> int:
|
| 155 |
-
operand_list = ["A", "B", "C"]
|
| 156 |
-
if key not in operand_list:
|
| 157 |
-
raise Exception(f"Unexpected operand {operand}")
|
| 158 |
-
|
| 159 |
-
return operand_list.index(key)
|
| 160 |
-
|
| 161 |
-
def find_alignment(self, shape: tuple, layout: cutlass_cppgen.LayoutType, operand=str) -> int:
|
| 162 |
-
"""
|
| 163 |
-
Returns the most preferable alignment for a given shape and layout
|
| 164 |
-
|
| 165 |
-
:param shape: extent of each dimension of the tensor
|
| 166 |
-
:type shape: tuple
|
| 167 |
-
:param layout: layout of the tensor
|
| 168 |
-
:type layout: cutlass_cppgen.LayoutType
|
| 169 |
-
:param operand: descriptor of the operand in question
|
| 170 |
-
:type operand: str
|
| 171 |
-
|
| 172 |
-
:return: maximum alignment supported by the data type combination and tensor size
|
| 173 |
-
:rtype: int
|
| 174 |
-
"""
|
| 175 |
-
operand_idx = self._operand_idx(operand)
|
| 176 |
-
|
| 177 |
-
# Determine the leading dimension of the shape
|
| 178 |
-
if layout == cutlass_cppgen.LayoutType.ColumnMajor:
|
| 179 |
-
ld = shape[-2]
|
| 180 |
-
elif layout == cutlass_cppgen.LayoutType.RowMajor:
|
| 181 |
-
ld = shape[-1]
|
| 182 |
-
elif layout == cutlass_cppgen.LayoutType.TensorNHWC:
|
| 183 |
-
ld = shape[-1]
|
| 184 |
-
else:
|
| 185 |
-
raise Exception(f"Unexpected or unsupported layout {layout}")
|
| 186 |
-
|
| 187 |
-
for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True):
|
| 188 |
-
alignment = int(alignments.split(" ")[operand_idx])
|
| 189 |
-
if ld % alignment == 0:
|
| 190 |
-
return alignment
|
| 191 |
-
|
| 192 |
-
# Default to alignment of 1 if no others match
|
| 193 |
-
return 1
|
| 194 |
-
|
| 195 |
-
def sort(self):
|
| 196 |
-
"""
|
| 197 |
-
Sorts each list of kernels in `kernels_by_alignment` in descending order of threadblock shape
|
| 198 |
-
"""
|
| 199 |
-
key = lambda op: (
|
| 200 |
-
op.tile_description.threadblock_shape[0]
|
| 201 |
-
* op.tile_description.threadblock_shape[1]
|
| 202 |
-
* op.tile_description.threadblock_shape[2]
|
| 203 |
-
)
|
| 204 |
-
for alignment in self.kernels_by_alignment.keys():
|
| 205 |
-
self.kernels_by_alignment[alignment].sort(key=key, reverse=True)
|
| 206 |
-
|
| 207 |
-
def supports_math_operation(self, math_operation: cutlass_cppgen.MathOperation) -> bool:
|
| 208 |
-
"""
|
| 209 |
-
Returns whether `math_operation` is supported by at least one operation.
|
| 210 |
-
|
| 211 |
-
:param math_operation: math operation to consider
|
| 212 |
-
:type math_operation: cutlass_cppgen.MathOperation
|
| 213 |
-
|
| 214 |
-
:return: whether math_operation is supported by at least one operation
|
| 215 |
-
:rtype: bool
|
| 216 |
-
"""
|
| 217 |
-
return math_operation is None or math_operation in self.math_operations
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
class ArchOptions:
|
| 221 |
-
"""
|
| 222 |
-
Structure for keeping track of kernels available on a given compute capability
|
| 223 |
-
|
| 224 |
-
:param target_cc: compute capability of the device on which kernels will be run
|
| 225 |
-
:type target_cc: int
|
| 226 |
-
:param kernel_cc: compute capability of the kernels to generate
|
| 227 |
-
:type kernel_cc: int
|
| 228 |
-
:param operation_kind: type of operation to register
|
| 229 |
-
:type operation_kind: cutlass_library.OperationKind
|
| 230 |
-
:param gemm_kinds: types of GEMM operations that can be included
|
| 231 |
-
:type gemm_kinds: list
|
| 232 |
-
:param allowed_math_operations: types of primitive math operations allowed
|
| 233 |
-
:type allowed_math_operations: list
|
| 234 |
-
"""
|
| 235 |
-
|
| 236 |
-
def __init__(
|
| 237 |
-
self,
|
| 238 |
-
target_cc: int,
|
| 239 |
-
kernel_cc: int,
|
| 240 |
-
operation_kind: cutlass_library.OperationKind,
|
| 241 |
-
gemm_kinds: list,
|
| 242 |
-
allowed_math_operations: list = [
|
| 243 |
-
cutlass_library.MathOperation.multiply_add,
|
| 244 |
-
cutlass_library.MathOperation.multiply_add_saturate,
|
| 245 |
-
cutlass_library.MathOperation.multiply_add_mixed_input_upcast,
|
| 246 |
-
cutlass_library.MathOperation.multiply_add_fast_f32
|
| 247 |
-
]
|
| 248 |
-
):
|
| 249 |
-
self.cc = kernel_cc
|
| 250 |
-
|
| 251 |
-
# Dictionary with following structure:
|
| 252 |
-
# Key: OpcodeClass
|
| 253 |
-
# Value: Dictionary with the following structure:
|
| 254 |
-
# Key: tuple of ((DataType, DataType, DataType), (LayoutType, LayoutType, LayoutType),
|
| 255 |
-
# representing ((element_a, element_b, element_accumulator), (layout_a, layout_b))
|
| 256 |
-
# Value: KernelsForDataType
|
| 257 |
-
self.operations_by_opclass = {}
|
| 258 |
-
self.op_class = None
|
| 259 |
-
self.allowed_math_operations = allowed_math_operations
|
| 260 |
-
|
| 261 |
-
if target_cc == 100 and kernel_cc == 90 or target_cc == 90 and kernel_cc == 100:
|
| 262 |
-
return
|
| 263 |
-
|
| 264 |
-
# Identify the method within CUTLASS generator script that generates kernel
|
| 265 |
-
# descriptions for the target CC
|
| 266 |
-
generate_function_name = "GenerateSM" + str(kernel_cc)
|
| 267 |
-
if not hasattr(cutlass_library.generator, generate_function_name):
|
| 268 |
-
cutlass_cppgen.logger.warning(f"No generator found for architecture {kernel_cc}")
|
| 269 |
-
return
|
| 270 |
-
generate_function = getattr(cutlass_library.generator, generate_function_name)
|
| 271 |
-
|
| 272 |
-
# Initialize a default manifest and populate it with valid kernel descriptions
|
| 273 |
-
# for the target CC
|
| 274 |
-
args = [
|
| 275 |
-
"--kernels=all",
|
| 276 |
-
f"--log-level={logging.getLevelName(cutlass_cppgen.logger.level)}"
|
| 277 |
-
]
|
| 278 |
-
manifest_args = cutlass_library.generator.define_parser().parse_args(args)
|
| 279 |
-
manifest = cutlass_library.manifest.Manifest(manifest_args)
|
| 280 |
-
generate_function(manifest, cutlass_cppgen._nvcc_version)
|
| 281 |
-
|
| 282 |
-
if operation_kind not in manifest.operations:
|
| 283 |
-
# No kernels generated for this architecture, this could be because the CUDA
|
| 284 |
-
# toolkit is insufficient to support operations in this CC
|
| 285 |
-
cutlass_cppgen.logger.warning(f"No operations of type {operation_kind} found for CC {kernel_cc}")
|
| 286 |
-
return
|
| 287 |
-
|
| 288 |
-
# Only one CC should be returned, given the setup above of calling only the generation scripts
|
| 289 |
-
# for a given CC
|
| 290 |
-
if len(manifest.operations[operation_kind].keys()) != 1 or kernel_cc not in manifest.operations[operation_kind]:
|
| 291 |
-
raise Exception(f"Error finding kernels for SM{kernel_cc}. Check that your CUDA toolkit version "
|
| 292 |
-
"is sufficient for the architecture in question.")
|
| 293 |
-
|
| 294 |
-
# Iterate through the available operations for this operation kind and
|
| 295 |
-
# find available opclasses and data types
|
| 296 |
-
for name, op_list in manifest.operations[operation_kind][kernel_cc].items():
|
| 297 |
-
for op in op_list:
|
| 298 |
-
|
| 299 |
-
if operation_kind == cutlass_library.OperationKind.Gemm:
|
| 300 |
-
if op.gemm_kind not in gemm_kinds:
|
| 301 |
-
continue
|
| 302 |
-
|
| 303 |
-
mi = op.tile_description.math_instruction
|
| 304 |
-
if mi.math_operation not in self.allowed_math_operations:
|
| 305 |
-
continue
|
| 306 |
-
|
| 307 |
-
# Prune operations that don't fit in shared memory
|
| 308 |
-
td = td_from_profiler_op(op)
|
| 309 |
-
if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]:
|
| 310 |
-
continue
|
| 311 |
-
|
| 312 |
-
if mi.opcode_class not in self.operations_by_opclass:
|
| 313 |
-
self.operations_by_opclass[mi.opcode_class] = {}
|
| 314 |
-
|
| 315 |
-
datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator)
|
| 316 |
-
layout_comb = (op.A.layout, op.B.layout)
|
| 317 |
-
|
| 318 |
-
# Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations
|
| 319 |
-
if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32):
|
| 320 |
-
# TF32 kernels only supported on SM80 and beyond
|
| 321 |
-
if self.cc < 80:
|
| 322 |
-
continue
|
| 323 |
-
elif self.cc == 90 or self.cc == 100:
|
| 324 |
-
if (op.A.element != cutlass_library.DataType.f32
|
| 325 |
-
or op.B.element != cutlass_library.DataType.f32
|
| 326 |
-
or op.C.element != cutlass_library.DataType.f32):
|
| 327 |
-
continue
|
| 328 |
-
|
| 329 |
-
datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32)
|
| 330 |
-
|
| 331 |
-
opclass_dict = self.operations_by_opclass[mi.opcode_class]
|
| 332 |
-
key = (datatype_comb, layout_comb)
|
| 333 |
-
if key not in opclass_dict:
|
| 334 |
-
opclass_dict[key] = KernelsForDataType(datatype_comb, layout_comb)
|
| 335 |
-
opclass_dict[key].add(op)
|
| 336 |
-
|
| 337 |
-
# Set the default opclass to TensorOp, if available. Otherwise default to SIMT
|
| 338 |
-
if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass:
|
| 339 |
-
self.op_class = cutlass_library.OpcodeClass.TensorOp
|
| 340 |
-
else:
|
| 341 |
-
self.op_class = cutlass_library.OpcodeClass.Simt
|
| 342 |
-
|
| 343 |
-
# The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels.
|
| 344 |
-
# Here, we generate additional versions via a generic TileDescription.
|
| 345 |
-
if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass:
|
| 346 |
-
self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {}
|
| 347 |
-
|
| 348 |
-
if operation_kind == cutlass_library.OperationKind.Gemm:
|
| 349 |
-
types = [
|
| 350 |
-
(cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8),
|
| 351 |
-
(cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32),
|
| 352 |
-
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
|
| 353 |
-
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
|
| 354 |
-
(cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
|
| 355 |
-
(cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
|
| 356 |
-
]
|
| 357 |
-
|
| 358 |
-
# Add FP8 A/B/C
|
| 359 |
-
fp8_types = [cutlass_library.DataType.e4m3, cutlass_library.DataType.e5m2]
|
| 360 |
-
for type_comb in combinations_with_replacement(fp8_types, 3):
|
| 361 |
-
types.append(type_comb)
|
| 362 |
-
|
| 363 |
-
# Add FP8 A/B with FP32 C
|
| 364 |
-
for type_comb in combinations_with_replacement(fp8_types, 2):
|
| 365 |
-
types.append(type_comb + (cutlass_cppgen.DataType.f32,))
|
| 366 |
-
|
| 367 |
-
layouts = [
|
| 368 |
-
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor),
|
| 369 |
-
(cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor),
|
| 370 |
-
(cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor),
|
| 371 |
-
(cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor),
|
| 372 |
-
]
|
| 373 |
-
elif operation_kind == cutlass_library.OperationKind.Conv2d:
|
| 374 |
-
types = [
|
| 375 |
-
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16),
|
| 376 |
-
(cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32),
|
| 377 |
-
(cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32),
|
| 378 |
-
(cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64),
|
| 379 |
-
]
|
| 380 |
-
|
| 381 |
-
layouts = [
|
| 382 |
-
(cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC),
|
| 383 |
-
]
|
| 384 |
-
else:
|
| 385 |
-
raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.")
|
| 386 |
-
|
| 387 |
-
alignment = 1
|
| 388 |
-
epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination
|
| 389 |
-
swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8
|
| 390 |
-
for type_comb in types:
|
| 391 |
-
for layout_comb in layouts:
|
| 392 |
-
comb = (type_comb, layout_comb)
|
| 393 |
-
if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]:
|
| 394 |
-
continue
|
| 395 |
-
|
| 396 |
-
A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment)
|
| 397 |
-
B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment)
|
| 398 |
-
C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment)
|
| 399 |
-
math_inst = cutlass_library.MathInstruction(
|
| 400 |
-
[1, 1, 1],
|
| 401 |
-
type_comb[0],
|
| 402 |
-
type_comb[1],
|
| 403 |
-
type_comb[2],
|
| 404 |
-
cutlass_library.OpcodeClass.Simt,
|
| 405 |
-
cutlass_library.MathOperation.multiply_add
|
| 406 |
-
)
|
| 407 |
-
|
| 408 |
-
td = cutlass_library.TileDescription(
|
| 409 |
-
[128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024)
|
| 410 |
-
|
| 411 |
-
# Prune operations that don't fit in shared memory
|
| 412 |
-
if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]:
|
| 413 |
-
continue
|
| 414 |
-
|
| 415 |
-
new_kernels = KernelsForDataType(type_comb, layout_comb)
|
| 416 |
-
|
| 417 |
-
if operation_kind == cutlass_library.OperationKind.Gemm:
|
| 418 |
-
new_operation = cutlass_library.manifest.GemmOperation(
|
| 419 |
-
cutlass_library.GemmKind.Universal, td.minimum_compute_capability,
|
| 420 |
-
td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor)
|
| 421 |
-
new_kernels.add(new_operation)
|
| 422 |
-
elif operation_kind == cutlass_library.OperationKind.Conv2d:
|
| 423 |
-
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
| 424 |
-
new_operation = cutlass_library.manifest.Conv2dOperation(
|
| 425 |
-
conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td,
|
| 426 |
-
A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor,
|
| 427 |
-
group_mode=GroupMode.SingleGroup
|
| 428 |
-
)
|
| 429 |
-
new_kernels.add(new_operation)
|
| 430 |
-
|
| 431 |
-
self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels
|
| 432 |
-
|
| 433 |
-
# Sort all operations
|
| 434 |
-
for oc in self.operations_by_opclass.keys():
|
| 435 |
-
for comb in self.operations_by_opclass[oc].keys():
|
| 436 |
-
self.operations_by_opclass[oc][comb].sort()
|
| 437 |
-
|
| 438 |
-
def opclass_supports_combination(
|
| 439 |
-
self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple, math_operation: cutlass_library.MathOperation
|
| 440 |
-
) -> bool:
|
| 441 |
-
"""
|
| 442 |
-
Returns whether the provided operation class supports the provided data type and layout combination
|
| 443 |
-
|
| 444 |
-
:param op_class: operation class to consider
|
| 445 |
-
:type op_class: cutlass_library.OpcodeClass
|
| 446 |
-
:param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator)
|
| 447 |
-
:type datatype_comb: tuple[cutlass_library.DataType]
|
| 448 |
-
:param layout_comb: tuple of data types for (layout_A, layout_B)
|
| 449 |
-
:type layout_comb: tuple[cutlass_library.LayoutType]
|
| 450 |
-
:param math_operation: math operation to consider or None if any can be considered
|
| 451 |
-
:type math_operation: cutlass_cppgen.MathOperation
|
| 452 |
-
|
| 453 |
-
:return: set of operation classes that support the provided data type and layout combination
|
| 454 |
-
:rtype: set
|
| 455 |
-
"""
|
| 456 |
-
if op_class not in self.operations_by_opclass:
|
| 457 |
-
raise Exception(f"Unexpected or unsupported operation class {op_class}")
|
| 458 |
-
|
| 459 |
-
if operations := self.operations_by_opclass[op_class].get((datatype_comb, layout_comb)):
|
| 460 |
-
if math_operation is not None:
|
| 461 |
-
return operations.supports_math_operation(math_operation)
|
| 462 |
-
else:
|
| 463 |
-
return True
|
| 464 |
-
|
| 465 |
-
return False
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
def supporting_opclasses(
|
| 469 |
-
self,
|
| 470 |
-
element_a: cutlass_library.DataType,
|
| 471 |
-
element_b: cutlass_library.DataType,
|
| 472 |
-
element_accumulator: cutlass_library.DataType,
|
| 473 |
-
layout_a: cutlass_library.LayoutType,
|
| 474 |
-
layout_b: cutlass_library.LayoutType,
|
| 475 |
-
math_operation: cutlass_library.MathOperation,
|
| 476 |
-
) -> set:
|
| 477 |
-
"""
|
| 478 |
-
Returns a set of operation classes that support the provided data type combination
|
| 479 |
-
|
| 480 |
-
:param element_a: data type of operand A
|
| 481 |
-
:type element_a: cutlass_library.DataType
|
| 482 |
-
:param element_b: data type of operand B
|
| 483 |
-
:type element_b: cutlass_library.DataType
|
| 484 |
-
:param element_accumulator: data type of accumulator
|
| 485 |
-
:type element_accumulator: cutlass_library.DataType
|
| 486 |
-
:param layout_a: layout of operand A
|
| 487 |
-
:type layout_a: cutlass_library.LayoutType
|
| 488 |
-
:param layout_b: layout of operand B
|
| 489 |
-
:type layout_b: cutlass_library.LayoutType
|
| 490 |
-
:param math_operation: math operation to consider
|
| 491 |
-
:type math_operation: cutlass_cppgen.MathOperation
|
| 492 |
-
|
| 493 |
-
:return: set of operation classes that support the provided data type combination
|
| 494 |
-
:rtype: set
|
| 495 |
-
"""
|
| 496 |
-
supporting_op_classes = set()
|
| 497 |
-
datatype_comb = (element_a, element_b, element_accumulator)
|
| 498 |
-
layout_comb = (layout_a, layout_b)
|
| 499 |
-
|
| 500 |
-
for op_class in self.operations_by_opclass.keys():
|
| 501 |
-
if self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
|
| 502 |
-
supporting_op_classes.add(op_class)
|
| 503 |
-
return supporting_op_classes
|
| 504 |
-
|
| 505 |
-
def operations(
|
| 506 |
-
self,
|
| 507 |
-
op_class: cutlass_library.OpcodeClass,
|
| 508 |
-
element_a: cutlass_library.DataType,
|
| 509 |
-
element_b: cutlass_library.DataType,
|
| 510 |
-
element_accumulator: cutlass_library.DataType,
|
| 511 |
-
layout_a: cutlass_library.LayoutType,
|
| 512 |
-
layout_b: cutlass_library.LayoutType,
|
| 513 |
-
math_operation: cutlass_library.MathOperation,
|
| 514 |
-
) -> KernelsForDataType:
|
| 515 |
-
"""
|
| 516 |
-
Returns whether the provided operation class supports the provided data type combination
|
| 517 |
-
|
| 518 |
-
:param op_class: operation class to consider
|
| 519 |
-
:type op_class: cutlass_library.OpcodeClass
|
| 520 |
-
:param element_a: data type of operand A
|
| 521 |
-
:type element_a: cutlass_library.DataType
|
| 522 |
-
:param element_b: data type of operand B
|
| 523 |
-
:type element_b: cutlass_library.DataType
|
| 524 |
-
:param element_accumulator: data type of accumulator
|
| 525 |
-
:type element_accumulator: cutlass_library.DataType
|
| 526 |
-
:param layout_a: layout of operand A
|
| 527 |
-
:type layout_a: cutlass_library.LayoutType
|
| 528 |
-
:param layout_b: layout of operand B
|
| 529 |
-
:type layout_b: cutlass_library.LayoutType
|
| 530 |
-
:param math_operation: math operation to consider
|
| 531 |
-
:type math_operation: cutlass_cppgen.MathOperation
|
| 532 |
-
|
| 533 |
-
:return: container of kernels by alignment supported by the provided combination of parameters
|
| 534 |
-
:rtype: KernelsForDataType
|
| 535 |
-
"""
|
| 536 |
-
datatype_comb = (element_a, element_b, element_accumulator)
|
| 537 |
-
layout_comb = (layout_a, layout_b)
|
| 538 |
-
if not self.opclass_supports_combination(op_class, datatype_comb, layout_comb, math_operation):
|
| 539 |
-
raise Exception(
|
| 540 |
-
f"Data type layout combination {datatype_comb}, {layout_comb} "
|
| 541 |
-
f"is not supported by opcode class {op_class} on CC {self.cc}."
|
| 542 |
-
)
|
| 543 |
-
return self.operations_by_opclass[op_class][(datatype_comb, layout_comb)]
|
| 544 |
-
|
| 545 |
-
|
| 546 |
-
class OptionRegistry:
|
| 547 |
-
"""
|
| 548 |
-
Container of all architecture-specific options
|
| 549 |
-
|
| 550 |
-
:param target_cc: compute capability of the device on which operations will be run
|
| 551 |
-
:type target_cc: int
|
| 552 |
-
"""
|
| 553 |
-
|
| 554 |
-
def __init__(self, target_cc: int):
|
| 555 |
-
self.registry = {}
|
| 556 |
-
|
| 557 |
-
if target_cc > 100 and (target_cc not in [101, 103, 120, 121]):
|
| 558 |
-
raise Exception(f"Unsupported compute capability {target_cc}. The CUTLASS Python interface only supports compute capabilities up to the Blackwell architecture.")
|
| 559 |
-
|
| 560 |
-
gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x]
|
| 561 |
-
operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d]
|
| 562 |
-
# Construct options for each CC
|
| 563 |
-
for kernel_cc in _generator_ccs:
|
| 564 |
-
self.registry[kernel_cc] = {}
|
| 565 |
-
for opkind in operation_kinds:
|
| 566 |
-
self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds)
|
| 567 |
-
|
| 568 |
-
def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions:
|
| 569 |
-
return self.registry.get(cc, None)[op_kind]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/__init__.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from cutlass_cppgen.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad
|
| 34 |
-
from cutlass_cppgen.op.gemm import Gemm
|
| 35 |
-
from cutlass_cppgen.op.gemm_grouped import GroupedGemm
|
| 36 |
-
from cutlass_cppgen.op.op import OperationBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/conv.py
DELETED
|
@@ -1,997 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Ease-of-use interface for constructing, compiling, and running CONVs
|
| 35 |
-
|
| 36 |
-
The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run
|
| 37 |
-
CONV2D operations in CUTLASS via Python, without specifying many configuration parameters.
|
| 38 |
-
Under the hood, the interface will select sensible default parameters for the many template
|
| 39 |
-
parameters for CUTLASS CONVs.
|
| 40 |
-
|
| 41 |
-
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
| 42 |
-
performance, one should specify and tune each configuration parameter.
|
| 43 |
-
|
| 44 |
-
The simplest example of using this interface is the following:
|
| 45 |
-
|
| 46 |
-
.. highlight:: python
|
| 47 |
-
.. code-block:: python
|
| 48 |
-
|
| 49 |
-
# A, B, C, and D are torch/numpy/cupy tensor objects
|
| 50 |
-
plan = cutlass_cppgen.op.Conv(A, B, C, D)
|
| 51 |
-
plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 52 |
-
|
| 53 |
-
One can also use the interface by specifying data types of operands at construction
|
| 54 |
-
and using different tensor objects with these data types at runtime:
|
| 55 |
-
|
| 56 |
-
.. highlight:: python
|
| 57 |
-
.. code-block:: python
|
| 58 |
-
|
| 59 |
-
# The following is shorthand for:
|
| 60 |
-
# cutlass_cppgen.op.Conv2d(kind="fprop",
|
| 61 |
-
# element_A=torch.float32, element_B=torch.float32,
|
| 62 |
-
# element_C=torch.float32, element_D=torch.float32,
|
| 63 |
-
# element_accumulator=torch.float32)
|
| 64 |
-
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=torch.float32)
|
| 65 |
-
|
| 66 |
-
A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
| 67 |
-
B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda')
|
| 68 |
-
C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda')
|
| 69 |
-
D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda')
|
| 70 |
-
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 71 |
-
|
| 72 |
-
A = torch.rand((32, 128), dtype=torch.float32, device='cuda')
|
| 73 |
-
B = torch.rand((128, 256), dtype=torch.float32, device='cuda')
|
| 74 |
-
C = torch.zeros((32, 256), dtype=torch.float32, device='cuda')
|
| 75 |
-
D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda')
|
| 76 |
-
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 77 |
-
|
| 78 |
-
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
| 79 |
-
kernel from its execution:
|
| 80 |
-
|
| 81 |
-
.. highlight:: python
|
| 82 |
-
.. code-block:: python
|
| 83 |
-
|
| 84 |
-
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
| 85 |
-
|
| 86 |
-
# Do other work...
|
| 87 |
-
|
| 88 |
-
plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 89 |
-
|
| 90 |
-
# Do other work...
|
| 91 |
-
|
| 92 |
-
plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1))
|
| 93 |
-
|
| 94 |
-
Elementwise activation functions are easily fused to the GEMM via the interface:
|
| 95 |
-
|
| 96 |
-
.. highlight:: python
|
| 97 |
-
.. code-block:: python
|
| 98 |
-
|
| 99 |
-
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
| 100 |
-
plan.activation = cutlass_cppgen.epilogue.relu
|
| 101 |
-
|
| 102 |
-
Operations can also be run asynchronously:
|
| 103 |
-
|
| 104 |
-
.. highlight:: python
|
| 105 |
-
.. code-block:: python
|
| 106 |
-
|
| 107 |
-
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=np.float32)
|
| 108 |
-
args = plan.run()
|
| 109 |
-
|
| 110 |
-
# Do other work...
|
| 111 |
-
|
| 112 |
-
args.sync()
|
| 113 |
-
"""
|
| 114 |
-
|
| 115 |
-
from __future__ import annotations
|
| 116 |
-
from typing import Optional
|
| 117 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 118 |
-
cuda = lazy_import("cuda.cuda")
|
| 119 |
-
cudart = lazy_import("cuda.cudart")
|
| 120 |
-
from cutlass_library import (
|
| 121 |
-
ConvKind,
|
| 122 |
-
ConvMode,
|
| 123 |
-
DataTypeSize,
|
| 124 |
-
IteratorAlgorithm,
|
| 125 |
-
OperationKind,
|
| 126 |
-
SplitKMode,
|
| 127 |
-
StrideSupport,
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
import cutlass_cppgen
|
| 131 |
-
from cutlass_cppgen import epilogue
|
| 132 |
-
from cutlass_cppgen.backend import compiler
|
| 133 |
-
from cutlass_cppgen.backend.conv2d_operation import Conv2dArguments, Conv2dOperation
|
| 134 |
-
from cutlass_cppgen.backend.reduction_operation import ReductionOperation, ReductionArguments
|
| 135 |
-
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
| 136 |
-
from cutlass_cppgen.op.op import OperationBase
|
| 137 |
-
from cutlass_cppgen.shape import Conv2DProblemSize, MatrixCoord
|
| 138 |
-
from cutlass_cppgen.utils import check, datatypes
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
class Conv2d(OperationBase):
|
| 142 |
-
"""
|
| 143 |
-
Constructs a ``Conv2d`` object.
|
| 144 |
-
|
| 145 |
-
The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C,
|
| 146 |
-
along with the data type of output D and that used for accumulation, are bound to the ``Conv``
|
| 147 |
-
object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed.
|
| 148 |
-
|
| 149 |
-
The constructor has optional parameters for flexibly setting these parameters. The following
|
| 150 |
-
constructors are equivalent:
|
| 151 |
-
|
| 152 |
-
.. highlight:: python
|
| 153 |
-
.. code-block:: python
|
| 154 |
-
|
| 155 |
-
# Use F32 for A, B, C, D, and accumulation in fprop
|
| 156 |
-
|
| 157 |
-
# Use the generic ``element`` parameter to concisely set all data types for operands to the same values.
|
| 158 |
-
Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f32)
|
| 159 |
-
|
| 160 |
-
# Explicitly specify the data types to use for A, B, C, and D.
|
| 161 |
-
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32,
|
| 162 |
-
element_C=cutlass_cppgen.DataType.f32, element_D=cutlass_cppgen.DataType.f32)
|
| 163 |
-
|
| 164 |
-
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
| 165 |
-
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
| 166 |
-
# have the same data type as those passed in here).
|
| 167 |
-
# A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout
|
| 168 |
-
Conv2d(kind="fprop", A=A, B=B, C=C, D=D)
|
| 169 |
-
|
| 170 |
-
# Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit
|
| 171 |
-
# those passed in via the generic ``element``
|
| 172 |
-
Conv2d(kind="fprop", element_A=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
|
| 173 |
-
element=cutlass_cppgen.DataType.f32)
|
| 174 |
-
|
| 175 |
-
The order of precedence for the setting of the data type for a given operand/output is as follows:
|
| 176 |
-
1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor
|
| 177 |
-
2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those
|
| 178 |
-
3) Otherwise, use the generic values (e.g., ``element``)
|
| 179 |
-
|
| 180 |
-
:param kind: the convolution kind (i.e. fprop, wgrad, and dgrad)
|
| 181 |
-
:type kind: str
|
| 182 |
-
:param A: tensor representing data type of operand A
|
| 183 |
-
:param B: tensor representing data type of operand B
|
| 184 |
-
:param C: tensor representing data type of operand C
|
| 185 |
-
:param D: tensor representing data type of operand D
|
| 186 |
-
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 187 |
-
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 188 |
-
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
| 189 |
-
:type element: cutlass_cppgen.DataType
|
| 190 |
-
:param element_A: data type to be used for operand A
|
| 191 |
-
:type element_A: cutlass_cppgen.DataType
|
| 192 |
-
:param element_B: data type to be used for operand B
|
| 193 |
-
:type element_B: cutlass_cppgen.DataType
|
| 194 |
-
:param element_C: data type to be used for operand C
|
| 195 |
-
:type element_C: cutlass_cppgen.DataType
|
| 196 |
-
:param element_D: data type to be used for operand D
|
| 197 |
-
:type element_D: cutlass_cppgen.DataType
|
| 198 |
-
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
| 199 |
-
:type element_accumulator: cutlass_cppgen.DataType
|
| 200 |
-
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
| 201 |
-
:type cc: int
|
| 202 |
-
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
| 203 |
-
:type kernel_cc: int
|
| 204 |
-
"""
|
| 205 |
-
def __init__(
|
| 206 |
-
self, kind="fprop",
|
| 207 |
-
A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0,
|
| 208 |
-
element=None,
|
| 209 |
-
element_A=None, element_B=None, element_C=None, element_D=None,
|
| 210 |
-
element_accumulator=None,
|
| 211 |
-
cc: int = None, kernel_cc: int = None
|
| 212 |
-
):
|
| 213 |
-
super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
|
| 214 |
-
# Verify the kernel cc
|
| 215 |
-
if self.current_cc in [90, 100, 101, 103]:
|
| 216 |
-
# The Conv2d kernel on Hopper (SM90) is currently unsupported
|
| 217 |
-
# Revert to use SM80-tagged kernels
|
| 218 |
-
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
| 219 |
-
self.specified_kernel_cc = 80
|
| 220 |
-
self._reset_options(80)
|
| 221 |
-
|
| 222 |
-
# The arch is used in testing
|
| 223 |
-
self.arch = self.current_cc
|
| 224 |
-
self.name = "conv2d" + kind
|
| 225 |
-
|
| 226 |
-
# The convolution kind. (concept: cutlass_library.library.ConvKind)
|
| 227 |
-
self.conv_kind = datatypes.getattr_enum(ConvKind, kind)
|
| 228 |
-
|
| 229 |
-
# The element types (concept: cutlass library types) of A, B, C, and D
|
| 230 |
-
elements = []
|
| 231 |
-
layouts = []
|
| 232 |
-
|
| 233 |
-
# Complete the data types based on user-provided arguments
|
| 234 |
-
for elt, tens, name in zip([element_A, element_B, element_C, element_D],
|
| 235 |
-
[A, B, C, D],
|
| 236 |
-
["A", "B", "C", "D"]):
|
| 237 |
-
if elt is not None and tens is not None:
|
| 238 |
-
raise Exception(f'Must not specify both element_{name} and tensor {name}')
|
| 239 |
-
if elt is None and tens is None and element is None:
|
| 240 |
-
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
|
| 241 |
-
|
| 242 |
-
elt_to_set = None
|
| 243 |
-
lay_to_set = None
|
| 244 |
-
|
| 245 |
-
if tens is not None:
|
| 246 |
-
elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
|
| 247 |
-
else:
|
| 248 |
-
elt_to_set = elt if elt is not None else element
|
| 249 |
-
|
| 250 |
-
assert elt_to_set is not None
|
| 251 |
-
|
| 252 |
-
# Currently we only support layout TensorNHWC
|
| 253 |
-
lay_to_set = cutlass_cppgen.LayoutType.TensorNHWC
|
| 254 |
-
elements.append(datatypes.library_type(elt_to_set))
|
| 255 |
-
layouts.append(lay_to_set)
|
| 256 |
-
|
| 257 |
-
self._element_a, self._element_b, self._element_c, self._element_d = elements
|
| 258 |
-
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
|
| 259 |
-
|
| 260 |
-
self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta
|
| 261 |
-
|
| 262 |
-
if element_accumulator is None:
|
| 263 |
-
self._element_accumulator = self._element_c
|
| 264 |
-
else:
|
| 265 |
-
self._element_accumulator = datatypes.library_type(element_accumulator)
|
| 266 |
-
|
| 267 |
-
# Default inputs if none is supplied in run()
|
| 268 |
-
self.A = A
|
| 269 |
-
self.B = B
|
| 270 |
-
self.C = C
|
| 271 |
-
self.D = D
|
| 272 |
-
|
| 273 |
-
self.alpha = alpha
|
| 274 |
-
self.beta = beta
|
| 275 |
-
|
| 276 |
-
# We only specify the stride of the swizzling functor here
|
| 277 |
-
# The actual swizzling functor is determined in run based on conv_kind and stride
|
| 278 |
-
self._swizzling_stride = 1
|
| 279 |
-
|
| 280 |
-
# Arguments that will be set to default value in _reset_operations
|
| 281 |
-
# The default tile_description and op_class are fetched from manifest of cutlass library
|
| 282 |
-
self._tile_description = None
|
| 283 |
-
self.op_class = None
|
| 284 |
-
# The default identity epilogue will be created
|
| 285 |
-
self.epilogue_functor = None
|
| 286 |
-
|
| 287 |
-
self._reset_operations()
|
| 288 |
-
|
| 289 |
-
# Arguments that will be determined online based on arguments of "run"
|
| 290 |
-
# based on stride, input/output channels, alignment, and conv_kind
|
| 291 |
-
self._iterator_algorithm = None
|
| 292 |
-
self._stride_support = None
|
| 293 |
-
|
| 294 |
-
def _reset_operations(self, reset_epilogue: bool = True):
|
| 295 |
-
# Set the default op class
|
| 296 |
-
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
|
| 297 |
-
layout_comb = (self._layout_a, self._layout_b)
|
| 298 |
-
|
| 299 |
-
self.possible_op_classes = self.options.supporting_opclasses(
|
| 300 |
-
self._element_a, self._element_b, self._element_accumulator,
|
| 301 |
-
self._layout_a, self._layout_b, self._math_operation
|
| 302 |
-
)
|
| 303 |
-
|
| 304 |
-
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
|
| 305 |
-
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
| 306 |
-
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
|
| 307 |
-
self.opclass = cutlass_cppgen.OpcodeClass.Simt
|
| 308 |
-
else:
|
| 309 |
-
if self._math_operation is not None:
|
| 310 |
-
math_op_str = f' and math operation {self._math_operation}'
|
| 311 |
-
else:
|
| 312 |
-
math_op_str = ''
|
| 313 |
-
|
| 314 |
-
raise Exception(f'No kernel configuration found for supported data type and layout '
|
| 315 |
-
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
| 316 |
-
|
| 317 |
-
if reset_epilogue:
|
| 318 |
-
self._reset_epilogue_functor_activation(epilogue.identity)
|
| 319 |
-
|
| 320 |
-
self.alignment_pref_A = min(
|
| 321 |
-
128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
|
| 322 |
-
self.alignment_pref_B = min(
|
| 323 |
-
128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
|
| 324 |
-
self.alignment_pref_C = min(
|
| 325 |
-
128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C")))
|
| 326 |
-
|
| 327 |
-
#
|
| 328 |
-
# Tile description Related
|
| 329 |
-
#
|
| 330 |
-
|
| 331 |
-
@property
|
| 332 |
-
def tile_description(self) -> TileDescription:
|
| 333 |
-
"""
|
| 334 |
-
Returns the tile description
|
| 335 |
-
"""
|
| 336 |
-
return self._tile_description
|
| 337 |
-
|
| 338 |
-
@tile_description.setter
|
| 339 |
-
def tile_description(
|
| 340 |
-
self, td=None):
|
| 341 |
-
"""
|
| 342 |
-
Set the tile description
|
| 343 |
-
|
| 344 |
-
:param td: tile description
|
| 345 |
-
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
|
| 346 |
-
{
|
| 347 |
-
"threadblock_shape": [int, int, int],
|
| 348 |
-
"warp_count": [int, int, int],
|
| 349 |
-
"stages": int,
|
| 350 |
-
"instruction_shape": [int, int, int] (optional),
|
| 351 |
-
"cluster_shape": [int, int, int] (optional)
|
| 352 |
-
}
|
| 353 |
-
"""
|
| 354 |
-
if td is None:
|
| 355 |
-
return
|
| 356 |
-
if isinstance(td, dict):
|
| 357 |
-
if self._tile_description is None:
|
| 358 |
-
op = self.possible_operations.default_operation(self._math_operation)
|
| 359 |
-
self._tile_description = datatypes.td_from_profiler_op(op)
|
| 360 |
-
if "cluster_shape" in td.keys():
|
| 361 |
-
if td["cluster_shape"] != [1, 1, 1]:
|
| 362 |
-
cutlass_cppgen.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.")
|
| 363 |
-
td["cluster_shape"] = [1, 1, 1]
|
| 364 |
-
td = self._tile_description.clone_and_update(td)
|
| 365 |
-
|
| 366 |
-
valid, msg = self._valid_tile_description(td)
|
| 367 |
-
if valid:
|
| 368 |
-
self._tile_description = td
|
| 369 |
-
else:
|
| 370 |
-
raise Exception(msg)
|
| 371 |
-
|
| 372 |
-
def _valid_tile_description(self, td: TileDescription) -> tuple:
|
| 373 |
-
"""
|
| 374 |
-
Checks whether the provided tile description is valid for the given compute capability. At present,
|
| 375 |
-
this checks the following:
|
| 376 |
-
|
| 377 |
-
- Does the tile description use a number of stages supported by the compute capability in question?
|
| 378 |
-
- Does the tile size requested fit within shared memory?
|
| 379 |
-
- Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
|
| 380 |
-
more non-unit cluster dimensions for pre-SM90 architectures)?
|
| 381 |
-
- Is the kernel schedule being used supported on the architecture in question?
|
| 382 |
-
|
| 383 |
-
:param td: tile description to validate
|
| 384 |
-
:type td: cutlass_cppgen.backend.TileDescription
|
| 385 |
-
:return: tuple in which the first element is a bool indicating that the tile description is valid
|
| 386 |
-
and the second element is a string providing an optional error message.
|
| 387 |
-
:rtype: tuple
|
| 388 |
-
"""
|
| 389 |
-
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td)
|
| 390 |
-
if not valid:
|
| 391 |
-
return (valid, msg)
|
| 392 |
-
|
| 393 |
-
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
|
| 394 |
-
if not valid:
|
| 395 |
-
return (valid, msg)
|
| 396 |
-
|
| 397 |
-
return valid, msg
|
| 398 |
-
|
| 399 |
-
def tile_descriptions(self) -> list:
|
| 400 |
-
"""
|
| 401 |
-
Returns a list of valid tile descriptions for the operations
|
| 402 |
-
|
| 403 |
-
:returns: list of valid tile descriptions for the operations
|
| 404 |
-
:rtype: list
|
| 405 |
-
"""
|
| 406 |
-
descriptions = []
|
| 407 |
-
description_str = []
|
| 408 |
-
for op in self.possible_operations.all_operations:
|
| 409 |
-
td = datatypes.td_from_profiler_op(op)
|
| 410 |
-
|
| 411 |
-
if self._math_operation is not None:
|
| 412 |
-
if td.math_instruction.math_operation != self._math_operation:
|
| 413 |
-
continue
|
| 414 |
-
|
| 415 |
-
if str(td) not in description_str:
|
| 416 |
-
description_str.append(str(td))
|
| 417 |
-
descriptions.append(td)
|
| 418 |
-
return descriptions
|
| 419 |
-
|
| 420 |
-
#
|
| 421 |
-
# Swizzling functor Related
|
| 422 |
-
#
|
| 423 |
-
|
| 424 |
-
@property
|
| 425 |
-
def swizzling_stride(self):
|
| 426 |
-
"""
|
| 427 |
-
Returns the stride of swizzling currently being used by the Conv2d
|
| 428 |
-
|
| 429 |
-
:return: swizzing stride
|
| 430 |
-
"""
|
| 431 |
-
return self._swizzling_stride
|
| 432 |
-
|
| 433 |
-
@swizzling_stride.setter
|
| 434 |
-
def swizzling_stride(self, stride: int):
|
| 435 |
-
"""
|
| 436 |
-
Sets the swizzling functor to the type specified by `swizzling_functor`
|
| 437 |
-
"""
|
| 438 |
-
if not isinstance(stride, int):
|
| 439 |
-
raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}")
|
| 440 |
-
self._swizzling_stride = stride
|
| 441 |
-
|
| 442 |
-
def _propose_swizzling_functor(self, stride):
|
| 443 |
-
"""
|
| 444 |
-
Automatically propose the swizzling functor based on the stride
|
| 445 |
-
"""
|
| 446 |
-
if self.conv_kind == ConvKind.Dgrad:
|
| 447 |
-
if stride[0] != 1 or stride[1] != 1:
|
| 448 |
-
return getattr(cutlass_cppgen.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}")
|
| 449 |
-
|
| 450 |
-
return getattr(cutlass_cppgen.swizzle, f"IdentitySwizzle{self._swizzling_stride}")
|
| 451 |
-
|
| 452 |
-
#
|
| 453 |
-
# Iterator Algorithm Related
|
| 454 |
-
#
|
| 455 |
-
|
| 456 |
-
@property
|
| 457 |
-
def iterator_algorithm(self) -> IteratorAlgorithm:
|
| 458 |
-
"""
|
| 459 |
-
Returns the iterator algorithm
|
| 460 |
-
"""
|
| 461 |
-
return self._iterator_algorithm
|
| 462 |
-
|
| 463 |
-
@iterator_algorithm.setter
|
| 464 |
-
def iterator_algorithm(self, alg: str):
|
| 465 |
-
"""
|
| 466 |
-
Sets the iterator algorithm
|
| 467 |
-
|
| 468 |
-
:param alg: The iterator algorithm
|
| 469 |
-
:type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels"
|
| 470 |
-
"""
|
| 471 |
-
iterator_alg = datatypes.getattr_enum(IteratorAlgorithm, alg)
|
| 472 |
-
|
| 473 |
-
# Check if the iterator algorithm is valid
|
| 474 |
-
if iterator_alg in [IteratorAlgorithm.FewChannels, IteratorAlgorithm.FixedChannels] and self.conv_kind != ConvKind.Fprop:
|
| 475 |
-
raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.")
|
| 476 |
-
|
| 477 |
-
self._iterator_algorithm = iterator_alg
|
| 478 |
-
|
| 479 |
-
def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> IteratorAlgorithm:
|
| 480 |
-
"""
|
| 481 |
-
Propose a valid iterator algorithm based on problem size and alignment
|
| 482 |
-
"""
|
| 483 |
-
if self.conv_kind == ConvKind.Fprop:
|
| 484 |
-
# Check whether the fixed channel is applicable
|
| 485 |
-
if problem_size.C == alignment_a:
|
| 486 |
-
return IteratorAlgorithm.FixedChannels
|
| 487 |
-
elif (problem_size.C % alignment_a == 0 and
|
| 488 |
-
problem_size.R <= 32 and problem_size.S <= 32):
|
| 489 |
-
return IteratorAlgorithm.Optimized
|
| 490 |
-
else:
|
| 491 |
-
return IteratorAlgorithm.Analytic
|
| 492 |
-
elif self.conv_kind == ConvKind.Dgrad:
|
| 493 |
-
if (problem_size.K % alignment_a == 0 and
|
| 494 |
-
problem_size.R <= 32 and problem_size.S <= 32 and
|
| 495 |
-
problem_size.C % alignment_b == 0):
|
| 496 |
-
return IteratorAlgorithm.Optimized
|
| 497 |
-
else:
|
| 498 |
-
return IteratorAlgorithm.Analytic
|
| 499 |
-
elif self.conv_kind == ConvKind.Wgrad:
|
| 500 |
-
if (problem_size.K % alignment_a == 0 and
|
| 501 |
-
problem_size.C % alignment_b == 0):
|
| 502 |
-
return IteratorAlgorithm.Optimized
|
| 503 |
-
else:
|
| 504 |
-
return IteratorAlgorithm.Analytic
|
| 505 |
-
|
| 506 |
-
def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool:
|
| 507 |
-
"""
|
| 508 |
-
Validate whether the user provide iterator algorithm works for the given problem size
|
| 509 |
-
"""
|
| 510 |
-
if self.conv_kind == ConvKind.Fprop:
|
| 511 |
-
if iterator_algorithm == IteratorAlgorithm.FixedChannels:
|
| 512 |
-
return problem_size.C == alignment_a
|
| 513 |
-
elif iterator_algorithm == IteratorAlgorithm.Optimized:
|
| 514 |
-
return (problem_size.C % alignment_a == 0 and
|
| 515 |
-
problem_size.R <= 32 and problem_size.S <= 32)
|
| 516 |
-
elif iterator_algorithm == IteratorAlgorithm.FewChannels:
|
| 517 |
-
return problem_size.C % alignment_a == 0
|
| 518 |
-
elif self.conv_kind == ConvKind.Dgrad:
|
| 519 |
-
if iterator_algorithm == IteratorAlgorithm.Optimized:
|
| 520 |
-
return (problem_size.K % alignment_a == 0 and
|
| 521 |
-
problem_size.R <= 32 and problem_size.S <= 32 and
|
| 522 |
-
problem_size.C % alignment_b == 0)
|
| 523 |
-
elif self.conv_kind == ConvKind.Wgrad:
|
| 524 |
-
if iterator_algorithm == IteratorAlgorithm.Optimized:
|
| 525 |
-
return (problem_size.K % alignment_a == 0 and
|
| 526 |
-
problem_size.C % alignment_b == 0)
|
| 527 |
-
|
| 528 |
-
return True
|
| 529 |
-
|
| 530 |
-
#
|
| 531 |
-
# Stride Support Related
|
| 532 |
-
#
|
| 533 |
-
|
| 534 |
-
def _propose_stride_support(self, stride):
|
| 535 |
-
if self.conv_kind == ConvKind.Dgrad:
|
| 536 |
-
if stride[0] == 1 and stride[1] == 1:
|
| 537 |
-
return StrideSupport.Unity
|
| 538 |
-
|
| 539 |
-
return StrideSupport.Strided
|
| 540 |
-
|
| 541 |
-
#
|
| 542 |
-
# Construct and Compilation
|
| 543 |
-
#
|
| 544 |
-
|
| 545 |
-
def construct(
|
| 546 |
-
self, tile_description: TileDescription = None,
|
| 547 |
-
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
| 548 |
-
iterator_algorithm: IteratorAlgorithm = None,
|
| 549 |
-
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
|
| 550 |
-
epilogue_functor=None) -> cutlass_cppgen.backend.Conv2dOperation:
|
| 551 |
-
"""
|
| 552 |
-
Constructs a ``cutlass_cppgen.backend.Conv2dOperation`` based on the input parameters and current
|
| 553 |
-
kernel specification of the ``Conv2d`` object.
|
| 554 |
-
|
| 555 |
-
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 556 |
-
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 557 |
-
:param alignment_A: alignment of operand A
|
| 558 |
-
:type alignment_A: int
|
| 559 |
-
:param alignment_B: alignment of operand B
|
| 560 |
-
:type alignment_B: int
|
| 561 |
-
:param alignment_C: alignment of operand C
|
| 562 |
-
:type alignment_C: int
|
| 563 |
-
:param iterator_algorithm: the iterator algorithm used
|
| 564 |
-
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
|
| 565 |
-
:param stride_support: the stride support of dgrad
|
| 566 |
-
:type stride_support: cutlass_library.library.StrideSupport
|
| 567 |
-
:param swizzling_functor: the swizzling functor
|
| 568 |
-
:type swizzling_functor: cutlass_cppgen.swizzle
|
| 569 |
-
:param epilogue_functor: the epilogue functor
|
| 570 |
-
|
| 571 |
-
:return: operation that was constructed
|
| 572 |
-
:rtype: cutlass_cppgen.backend.Conv2dOperation
|
| 573 |
-
"""
|
| 574 |
-
# Get alignment
|
| 575 |
-
alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A)
|
| 576 |
-
alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B)
|
| 577 |
-
alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C)
|
| 578 |
-
|
| 579 |
-
tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
|
| 580 |
-
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
| 581 |
-
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
| 582 |
-
|
| 583 |
-
if tile_description is None:
|
| 584 |
-
if self.tile_description is not None:
|
| 585 |
-
tile_description = self.tile_description
|
| 586 |
-
else:
|
| 587 |
-
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
| 588 |
-
tile_description = datatypes.td_from_profiler_op(op)
|
| 589 |
-
else:
|
| 590 |
-
valid, err_str = self._valid_tile_description(tile_description)
|
| 591 |
-
if not valid:
|
| 592 |
-
raise Exception(f"Invalid tile description. {err_str}")
|
| 593 |
-
self.tile_description = tile_description
|
| 594 |
-
|
| 595 |
-
if iterator_algorithm is None:
|
| 596 |
-
# If the iterator algorithm is already set
|
| 597 |
-
if self.iterator_algorithm is not None:
|
| 598 |
-
iterator_algorithm = self.iterator_algorithm
|
| 599 |
-
else:
|
| 600 |
-
# Otherwise, we conservatively use the analytic iterator for correctness
|
| 601 |
-
iterator_algorithm = IteratorAlgorithm.Analytic
|
| 602 |
-
|
| 603 |
-
if stride_support is None:
|
| 604 |
-
# If the stride support is already set
|
| 605 |
-
if self._stride_support is not None:
|
| 606 |
-
stride_support = self._stride_support
|
| 607 |
-
else:
|
| 608 |
-
# Otherwise, we assume strided
|
| 609 |
-
stride_support = StrideSupport.Strided
|
| 610 |
-
|
| 611 |
-
if swizzling_functor is None:
|
| 612 |
-
# If the swizzling functor is already set
|
| 613 |
-
swizzling_functor = self._propose_swizzling_functor(stride=(2, 2))
|
| 614 |
-
|
| 615 |
-
if epilogue_functor is None:
|
| 616 |
-
if self.epilogue_functor is not None:
|
| 617 |
-
epilogue_functor = self.epilogue_functor
|
| 618 |
-
else:
|
| 619 |
-
epilogue_functor = self._create_epilogue_functor_activation(self._activation)
|
| 620 |
-
|
| 621 |
-
# Reset the alignment of the epilogue functor
|
| 622 |
-
epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor)
|
| 623 |
-
|
| 624 |
-
operation = Conv2dOperation(
|
| 625 |
-
conv_kind=self.conv_kind,
|
| 626 |
-
iterator_algorithm=iterator_algorithm,
|
| 627 |
-
arch=self.current_cc,
|
| 628 |
-
tile_description=tile_description,
|
| 629 |
-
A=tensor_A, B=tensor_B, C=tensor_C,
|
| 630 |
-
stride_support=stride_support,
|
| 631 |
-
epilogue_functor=epilogue_functor,
|
| 632 |
-
swizzling_functor=swizzling_functor,
|
| 633 |
-
)
|
| 634 |
-
|
| 635 |
-
return operation
|
| 636 |
-
|
| 637 |
-
def compile(self, tile_description: TileDescription = None,
|
| 638 |
-
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
| 639 |
-
iterator_algorithm: IteratorAlgorithm = None,
|
| 640 |
-
stride_support = None, swizzling_functor: cutlass_cppgen.swizzle = None,
|
| 641 |
-
epilogue_functor = None, print_module: bool = False) -> cutlass_cppgen.backend.Conv2dOperation:
|
| 642 |
-
"""
|
| 643 |
-
Emits and compiles the kernel currently specified. If ``tile_description`` and any
|
| 644 |
-
of the ``alignment`` parameters are set, the kernel will be chosen using this
|
| 645 |
-
tile description and alignments. Otherwise, a default tile description and alignment
|
| 646 |
-
will be used.
|
| 647 |
-
|
| 648 |
-
::param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 649 |
-
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 650 |
-
:param alignment_A: alignment of operand A
|
| 651 |
-
:type alignment_A: int
|
| 652 |
-
:param alignment_B: alignment of operand B
|
| 653 |
-
:type alignment_B: int
|
| 654 |
-
:param alignment_C: alignment of operand C
|
| 655 |
-
:type alignment_C: int
|
| 656 |
-
:param iterator_algorithm: the iterator algorithm used
|
| 657 |
-
:type iterator_algorithm: cutlass_library.library.IteratorAlgorithm
|
| 658 |
-
:param stride_support: the stride support of dgrad
|
| 659 |
-
:type stride_support: cutlass_library.library.StrideSupport
|
| 660 |
-
:param swizzling_functor: the swizzling functor
|
| 661 |
-
:type swizzling_functor: cutlass_cppgen.swizzle
|
| 662 |
-
:param epilogue_functor: the epilogue functor
|
| 663 |
-
|
| 664 |
-
:return: operation that was compiled
|
| 665 |
-
:rtype: cutlass_cppgen.backend.Conv2dOperation
|
| 666 |
-
"""
|
| 667 |
-
|
| 668 |
-
self.operation = self.construct(
|
| 669 |
-
tile_description, alignment_A, alignment_B, alignment_C,
|
| 670 |
-
iterator_algorithm, stride_support, swizzling_functor, epilogue_functor)
|
| 671 |
-
|
| 672 |
-
if print_module:
|
| 673 |
-
print(self.operation.rt_module.emit())
|
| 674 |
-
|
| 675 |
-
compiler.add_module([self.operation,])
|
| 676 |
-
return self.operation
|
| 677 |
-
|
| 678 |
-
#
|
| 679 |
-
# Run Related
|
| 680 |
-
#
|
| 681 |
-
|
| 682 |
-
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
|
| 683 |
-
"""
|
| 684 |
-
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
|
| 685 |
-
is raised if it does not.
|
| 686 |
-
|
| 687 |
-
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 688 |
-
:type tensor: numpy/cupy/torch array/tensor object
|
| 689 |
-
:param ref_dtype: data type for the tensor that this object was initialized to
|
| 690 |
-
:param name: identifier of the tensor to verify. Used in raising exceptions
|
| 691 |
-
:type name: str
|
| 692 |
-
"""
|
| 693 |
-
dtype, _ = datatypes.get_datatype_and_layout(tensor)
|
| 694 |
-
if dtype != ref_type:
|
| 695 |
-
raise Exception(f'Tensor {name} with type and layout {dtype} '
|
| 696 |
-
f'does not match the expected type of {ref_type}.')
|
| 697 |
-
|
| 698 |
-
def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation):
|
| 699 |
-
if self.conv_kind == ConvKind.Fprop:
|
| 700 |
-
input = A
|
| 701 |
-
weight = B
|
| 702 |
-
output = C
|
| 703 |
-
output_tensor = "C"
|
| 704 |
-
elif self.conv_kind == ConvKind.Dgrad:
|
| 705 |
-
output = A
|
| 706 |
-
weight = B
|
| 707 |
-
input = C
|
| 708 |
-
output_tensor = "A"
|
| 709 |
-
elif self.conv_kind == ConvKind.Wgrad:
|
| 710 |
-
output = A
|
| 711 |
-
input = B
|
| 712 |
-
weight = C
|
| 713 |
-
output_tensor = "A"
|
| 714 |
-
else:
|
| 715 |
-
raise Exception(f"Convolution kind {self.conv_kind} is not supported")
|
| 716 |
-
|
| 717 |
-
N_, H_, W_, C_ = datatypes.get_tensor_shape(input, op="CONV")
|
| 718 |
-
K_, R_, S_, _ = datatypes.get_tensor_shape(weight, op="CONV")
|
| 719 |
-
_, P_, Q_, _ = datatypes.get_tensor_shape(output, op="CONV")
|
| 720 |
-
|
| 721 |
-
problem_size = Conv2DProblemSize(
|
| 722 |
-
N_, H_, W_, C_,
|
| 723 |
-
K_, R_, S_, C_,
|
| 724 |
-
padding[0], padding[1],
|
| 725 |
-
stride[0], stride[1],
|
| 726 |
-
dilation[0], dilation[1],
|
| 727 |
-
ConvMode.CrossCorrelation,
|
| 728 |
-
1, 1
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
if P_ != problem_size.P or Q_ != problem_size.Q:
|
| 732 |
-
raise Exception(
|
| 733 |
-
f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})")
|
| 734 |
-
|
| 735 |
-
return problem_size
|
| 736 |
-
|
| 737 |
-
def run(self, A=None, B=None, C=None, D=None,
|
| 738 |
-
stride=(1, 1), padding=(0, 0), dilation=(1, 1),
|
| 739 |
-
alpha=None, beta=None,
|
| 740 |
-
split_k=("serial", 1), sync: bool = True,
|
| 741 |
-
print_module: bool = False,
|
| 742 |
-
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 743 |
-
"""
|
| 744 |
-
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
|
| 745 |
-
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
|
| 746 |
-
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
|
| 747 |
-
parameters provided in the call, or from those
|
| 748 |
-
passed in on the construction of this object -- one of the two must be specified.
|
| 749 |
-
|
| 750 |
-
By default, this call returns only once the kernel has completed. To launch the kernel
|
| 751 |
-
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
|
| 752 |
-
caller to syncrhonize the results of the kernel before attempting to access outputs
|
| 753 |
-
by calling ``sync()`` on the arguments returned from this call.
|
| 754 |
-
|
| 755 |
-
:param A: tensor representing data type and layout of operand A
|
| 756 |
-
:param B: tensor representing data type and layout of operand B
|
| 757 |
-
:param C: tensor representing data type and layout of operand C
|
| 758 |
-
:param D: tensor representing data type and layout of operand D
|
| 759 |
-
:param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1)
|
| 760 |
-
:param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0)
|
| 761 |
-
:param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1)
|
| 762 |
-
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 763 |
-
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 764 |
-
:param split_k: a tuple (split_k_mode, split_k_slices)
|
| 765 |
-
:param sync: whether the call should wait for the kernel to complete before returning
|
| 766 |
-
:type sync: bool
|
| 767 |
-
:param print_module: whether to print the emitted C++ code
|
| 768 |
-
:type print_module: bool
|
| 769 |
-
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 770 |
-
:type stream: :class:`cuda.cuda.CUstream`
|
| 771 |
-
|
| 772 |
-
:return: arguments passed in to the kernel
|
| 773 |
-
:rtype: cutlass_cppgen.backend.Conv2dArguments
|
| 774 |
-
"""
|
| 775 |
-
if not stream:
|
| 776 |
-
stream = cuda.CUstream(0)
|
| 777 |
-
super().run_setup()
|
| 778 |
-
|
| 779 |
-
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
|
| 780 |
-
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
|
| 781 |
-
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
|
| 782 |
-
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
|
| 783 |
-
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
| 784 |
-
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
| 785 |
-
|
| 786 |
-
# handle the case when there is no C
|
| 787 |
-
if C is None:
|
| 788 |
-
if beta != 0:
|
| 789 |
-
raise Exception(f"With beta {beta} != 0, C has to be provided.")
|
| 790 |
-
else:
|
| 791 |
-
C = D
|
| 792 |
-
|
| 793 |
-
# Construct problem size based on input
|
| 794 |
-
# It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
|
| 795 |
-
problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)
|
| 796 |
-
|
| 797 |
-
# Propose stride support based on input
|
| 798 |
-
stride_support = self._propose_stride_support(stride)
|
| 799 |
-
|
| 800 |
-
# Propose swizzling functor
|
| 801 |
-
swizzling_functor = self._propose_swizzling_functor(stride)
|
| 802 |
-
|
| 803 |
-
shape_a = datatypes.get_tensor_shape(A, op="CONV")
|
| 804 |
-
shape_b = datatypes.get_tensor_shape(B, op="CONV")
|
| 805 |
-
shape_c = datatypes.get_tensor_shape(C, op="CONV")
|
| 806 |
-
|
| 807 |
-
# Get the alignment
|
| 808 |
-
alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
|
| 809 |
-
alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
|
| 810 |
-
alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")
|
| 811 |
-
|
| 812 |
-
alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
|
| 813 |
-
alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
|
| 814 |
-
alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)
|
| 815 |
-
|
| 816 |
-
# Propose iterator algorithm based on input
|
| 817 |
-
if self._iterator_algorithm is None:
|
| 818 |
-
# Propose a default iterator algorithm based on the problem size
|
| 819 |
-
iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
|
| 820 |
-
else:
|
| 821 |
-
if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
|
| 822 |
-
iterator_algorithm = self._iterator_algorithm
|
| 823 |
-
else:
|
| 824 |
-
raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")
|
| 825 |
-
|
| 826 |
-
epilogue_args = [alpha, beta]
|
| 827 |
-
|
| 828 |
-
if hasattr(self, "_activation_args"):
|
| 829 |
-
if isinstance(self._activation_args, list):
|
| 830 |
-
epilogue_args += self._activation_args
|
| 831 |
-
else:
|
| 832 |
-
epilogue_args.append(self._activation_args)
|
| 833 |
-
|
| 834 |
-
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 835 |
-
epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
|
| 836 |
-
else:
|
| 837 |
-
epilogue_functor = self.epilogue_functor
|
| 838 |
-
|
| 839 |
-
# The alignment is determined by the iterator function (I believe)
|
| 840 |
-
self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
| 841 |
-
alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
|
| 842 |
-
swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)
|
| 843 |
-
|
| 844 |
-
# Create reduction operation for parallel split-k
|
| 845 |
-
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 846 |
-
epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
|
| 847 |
-
self.reduction_operation = ReductionOperation(
|
| 848 |
-
shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
|
| 849 |
-
element_accumulator=self._element_accumulator,
|
| 850 |
-
element_compute=self._element_accumulator,
|
| 851 |
-
epilogue_functor=epilogue_functor_reduction,
|
| 852 |
-
count=alignment_c
|
| 853 |
-
)
|
| 854 |
-
if print_module:
|
| 855 |
-
print(self.reduction_operation.rt_module.emit())
|
| 856 |
-
compiler.add_module([self.reduction_operation,])
|
| 857 |
-
|
| 858 |
-
arguments = Conv2dArguments(
|
| 859 |
-
operation=self.operation, problem_size=problem_size,
|
| 860 |
-
A=A, B=B, C=C, D=D,
|
| 861 |
-
output_op=self.operation.epilogue_type(*epilogue_args),
|
| 862 |
-
split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
|
| 863 |
-
split_k_slices=split_k[1],
|
| 864 |
-
stream=stream
|
| 865 |
-
)
|
| 866 |
-
|
| 867 |
-
self.operation.run(arguments)
|
| 868 |
-
|
| 869 |
-
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 870 |
-
implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
|
| 871 |
-
reduction_arguments = ReductionArguments(
|
| 872 |
-
self.reduction_operation,
|
| 873 |
-
problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
|
| 874 |
-
partitions=split_k[1],
|
| 875 |
-
workspace=arguments.ptr_D,
|
| 876 |
-
destination=D,
|
| 877 |
-
source=C,
|
| 878 |
-
output_op=self.reduction_operation.epilogue_type(*epilogue_args),
|
| 879 |
-
stream=stream
|
| 880 |
-
)
|
| 881 |
-
self.reduction_operation.run(reduction_arguments)
|
| 882 |
-
|
| 883 |
-
if sync:
|
| 884 |
-
if split_k[0] == "parallel" and split_k[1] > 1:
|
| 885 |
-
reduction_arguments.sync()
|
| 886 |
-
|
| 887 |
-
# Free memory allocated by args because we are not
|
| 888 |
-
# calling `arguments.sync()` in this case (which will free memory)
|
| 889 |
-
arguments.free()
|
| 890 |
-
else:
|
| 891 |
-
arguments.sync()
|
| 892 |
-
|
| 893 |
-
return arguments
|
| 894 |
-
|
| 895 |
-
#
|
| 896 |
-
# Helper functions
|
| 897 |
-
#
|
| 898 |
-
@staticmethod
|
| 899 |
-
def output_size(input_size, weight_size, padding, stride, dilation):
|
| 900 |
-
problem_size = Conv2DProblemSize(
|
| 901 |
-
*input_size,
|
| 902 |
-
*weight_size,
|
| 903 |
-
padding[0], padding[1],
|
| 904 |
-
stride[0], stride[1],
|
| 905 |
-
dilation[0], dilation[1],
|
| 906 |
-
ConvMode.CrossCorrelation,
|
| 907 |
-
1, 1
|
| 908 |
-
)
|
| 909 |
-
return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K)
|
| 910 |
-
|
| 911 |
-
|
| 912 |
-
#
|
| 913 |
-
# Easy to use interfaces for fprop, wgrad, and dgrad
|
| 914 |
-
#
|
| 915 |
-
|
| 916 |
-
class Conv2dFprop(Conv2d):
|
| 917 |
-
def __init__(
|
| 918 |
-
self,
|
| 919 |
-
input=None, weight=None, C=None, output=None, alpha=1, beta=0,
|
| 920 |
-
element=None,
|
| 921 |
-
element_input=None, element_weight=None, element_C=None, element_output=None,
|
| 922 |
-
element_accumulator=None,
|
| 923 |
-
cc: int = None, kernel_cc: int = None):
|
| 924 |
-
A, B, D = input, weight, output
|
| 925 |
-
element_A, element_B, element_D = element_input, element_weight, element_output
|
| 926 |
-
super().__init__(
|
| 927 |
-
"fprop", A, B, C, D, alpha, beta, element,
|
| 928 |
-
element_A, element_B, element_C, element_D,
|
| 929 |
-
element_accumulator, cc, kernel_cc)
|
| 930 |
-
|
| 931 |
-
def run(
|
| 932 |
-
self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
|
| 933 |
-
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
|
| 934 |
-
sync: bool = True, print_module: bool = False,
|
| 935 |
-
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 936 |
-
|
| 937 |
-
if not stream:
|
| 938 |
-
stream = cuda.CUstream(0)
|
| 939 |
-
|
| 940 |
-
A, B, D = input, weight, output
|
| 941 |
-
return super().run(
|
| 942 |
-
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
class Conv2dDgrad(Conv2d):
|
| 946 |
-
def __init__(
|
| 947 |
-
self,
|
| 948 |
-
grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0,
|
| 949 |
-
element=None,
|
| 950 |
-
element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None,
|
| 951 |
-
element_accumulator=None,
|
| 952 |
-
cc: int = None, kernel_cc: int = None):
|
| 953 |
-
A, B, D = grad_output, weight, grad_input
|
| 954 |
-
element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input
|
| 955 |
-
super().__init__(
|
| 956 |
-
"dgrad", A, B, C, D, alpha, beta, element,
|
| 957 |
-
element_A, element_B, element_C, element_D,
|
| 958 |
-
element_accumulator, cc, kernel_cc)
|
| 959 |
-
|
| 960 |
-
def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
|
| 961 |
-
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
|
| 962 |
-
sync: bool = True, print_module: bool = False,
|
| 963 |
-
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 964 |
-
#
|
| 965 |
-
if not stream:
|
| 966 |
-
stream = cuda.CUstream(0)
|
| 967 |
-
|
| 968 |
-
A, B, D = grad_output, weight, grad_input
|
| 969 |
-
return super().run(
|
| 970 |
-
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
class Conv2dWgrad(Conv2d):
|
| 974 |
-
def __init__(
|
| 975 |
-
self,
|
| 976 |
-
grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0,
|
| 977 |
-
element=None,
|
| 978 |
-
element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None,
|
| 979 |
-
element_accumulator=None,
|
| 980 |
-
cc: int = None, kernel_cc: int = None):
|
| 981 |
-
A, B, D = grad_output, input, grad_weight
|
| 982 |
-
element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight
|
| 983 |
-
super().__init__(
|
| 984 |
-
"wgrad", A, B, C, D, alpha, beta, element,
|
| 985 |
-
element_A, element_B, element_C, element_D,
|
| 986 |
-
element_accumulator, cc, kernel_cc)
|
| 987 |
-
|
| 988 |
-
def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
|
| 989 |
-
stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
|
| 990 |
-
sync: bool = True, print_module: bool = False,
|
| 991 |
-
stream: Optional[cuda.CUstream] = None) -> Conv2dArguments:
|
| 992 |
-
if not stream:
|
| 993 |
-
stream = cuda.CUstream(0)
|
| 994 |
-
|
| 995 |
-
A, B, D = grad_output, input, grad_weight
|
| 996 |
-
return super().run(
|
| 997 |
-
A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm.py
DELETED
|
@@ -1,725 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Ease-of-use interface for constructing, compiling, and running GEMMs.
|
| 35 |
-
|
| 36 |
-
The ``Gemm`` interface is meant to allow one to easily instantiate, compile, and run
|
| 37 |
-
GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
|
| 38 |
-
Under the hood, the interface will select sensible default parameters for the many template
|
| 39 |
-
parameters for CUTLASS GEMMs.
|
| 40 |
-
|
| 41 |
-
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
| 42 |
-
performance, one should specify and tune each configuration parameter.
|
| 43 |
-
|
| 44 |
-
The simplest example of using this interface is the following:
|
| 45 |
-
|
| 46 |
-
.. highlight:: python
|
| 47 |
-
.. code-block:: python
|
| 48 |
-
|
| 49 |
-
# A, B, C, and D are torch/numpy/cupy tensor objects
|
| 50 |
-
plan = cutlass_cppgen.op.Gemm(A, B, C, D)
|
| 51 |
-
plan.run()
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
One can also use the interface by specifying data types of operands at construction
|
| 55 |
-
and using different tensor objects with these data types at runtime:
|
| 56 |
-
|
| 57 |
-
.. highlight:: python
|
| 58 |
-
.. code-block:: python
|
| 59 |
-
|
| 60 |
-
# The following is shorthand for:
|
| 61 |
-
# cutlass_cppgen.op.Gemm(element_A=torch.float32, element_B=torch.float32,
|
| 62 |
-
# element_C=torch.float32, element_D=torch.float32,
|
| 63 |
-
# element_accumulator=torch.float32,
|
| 64 |
-
# layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 65 |
-
plan = cutlass_cppgen.op.Gemm(element=torch.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 66 |
-
|
| 67 |
-
A0 = torch.rand((128, 256), device='cuda')
|
| 68 |
-
B0 = torch.rand((256, 64), device='cuda')
|
| 69 |
-
C0 = torch.zeros((128, 64), device='cuda')
|
| 70 |
-
D0 = torch.zeros((128, 64), device.'cuda')
|
| 71 |
-
plan.run(A0, B0, C0, D0)
|
| 72 |
-
|
| 73 |
-
A = torch.rand((32, 128), device='cuda')
|
| 74 |
-
B = torch.rand((128, 256), device='cuda')
|
| 75 |
-
C = torch.zeros((32, 256), device='cuda')
|
| 76 |
-
D = torch.zeros((32, 256), device.'cuda')
|
| 77 |
-
plan.run(A1, B1, C1, D1)
|
| 78 |
-
|
| 79 |
-
The interface additionally enables one to decouple the compilation of the underlying CUTLASS
|
| 80 |
-
kernel from its execution:
|
| 81 |
-
|
| 82 |
-
.. highlight:: python
|
| 83 |
-
.. code-block:: python
|
| 84 |
-
|
| 85 |
-
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 86 |
-
plan.compile()
|
| 87 |
-
|
| 88 |
-
# Do other work...
|
| 89 |
-
|
| 90 |
-
plan.run(A0, B0, C0, D0)
|
| 91 |
-
|
| 92 |
-
# Do other work...
|
| 93 |
-
|
| 94 |
-
plan.run(A1, B1, C1, D1)
|
| 95 |
-
|
| 96 |
-
Elementwise activation functions are easily fused to the GEMM via the interface:
|
| 97 |
-
|
| 98 |
-
.. highlight:: python
|
| 99 |
-
.. code-block:: python
|
| 100 |
-
|
| 101 |
-
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 102 |
-
plan.activation = cutlass_cppgen.epilogue.relu
|
| 103 |
-
|
| 104 |
-
Operations can also be run asynchronously:
|
| 105 |
-
|
| 106 |
-
.. highlight:: python
|
| 107 |
-
.. code-block:: python
|
| 108 |
-
|
| 109 |
-
plan = cutlass_cppgen.op.Gemm(element=np.float32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 110 |
-
args = plan.run()
|
| 111 |
-
|
| 112 |
-
# Do other work...
|
| 113 |
-
|
| 114 |
-
args.sync()
|
| 115 |
-
"""
|
| 116 |
-
from __future__ import annotations
|
| 117 |
-
from typing import Optional
|
| 118 |
-
from math import prod
|
| 119 |
-
|
| 120 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 121 |
-
cuda = lazy_import("cuda.cuda")
|
| 122 |
-
from cutlass_library import (
|
| 123 |
-
DataType,
|
| 124 |
-
DataTypeSize,
|
| 125 |
-
GemmUniversalMode,
|
| 126 |
-
KernelScheduleSuffixes,
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
import cutlass_cppgen
|
| 130 |
-
from cutlass_cppgen import epilogue, swizzle
|
| 131 |
-
from cutlass_cppgen.backend import compiler
|
| 132 |
-
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
| 133 |
-
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
| 134 |
-
from cutlass_cppgen.backend.library import TensorDescription, TileDescription
|
| 135 |
-
from cutlass_cppgen.op.op import OperationBase
|
| 136 |
-
from cutlass_cppgen.shape import GemmCoord
|
| 137 |
-
from cutlass_cppgen.utils import check, datatypes
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
class Gemm(OperationBase):
|
| 141 |
-
"""
|
| 142 |
-
Constructs a ``Gemm`` object.
|
| 143 |
-
|
| 144 |
-
The data types and layouts of operands A, B, and C, along with the data type of output D
|
| 145 |
-
and that used for accumulation, are bound to the ``Gemm`` object throughout its lifetime --
|
| 146 |
-
these are not to be changed after a ``Gemm`` has been constructed.
|
| 147 |
-
|
| 148 |
-
The constructor has optional parameters for flexibly setting these parameters. The following
|
| 149 |
-
constructors are equivalent:
|
| 150 |
-
|
| 151 |
-
.. highlight:: python
|
| 152 |
-
.. code-block:: python
|
| 153 |
-
|
| 154 |
-
# Use F32 for A, B, C, D, and accumulation. All operands are row major.
|
| 155 |
-
|
| 156 |
-
# Use the generic ``element`` and ``layout`` parameters to concisely set all data types and layouts
|
| 157 |
-
# for operands to the same values.
|
| 158 |
-
Gemm(element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 159 |
-
|
| 160 |
-
# Explicitly specify the data types to use for A, B, C, and D. Use the generic ``layout``.
|
| 161 |
-
Gemm(element_A=cutlass_cppgen.DataType.f32, element_B=cutlass_cppgen.DataType.f32, element_C=cutlass_cppgen.DataType.f32,
|
| 162 |
-
element_D=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 163 |
-
|
| 164 |
-
# Set the data types and elements from existing tensors. Note that one can use different tensors when
|
| 165 |
-
# executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must
|
| 166 |
-
# have the same data type and layout as those passed in here).
|
| 167 |
-
# A, B, C, and D are row-major torch.Tensor objects of type torch.float32
|
| 168 |
-
Gemm(A=A, B=B, C=C, D=D)
|
| 169 |
-
|
| 170 |
-
# Use the generic ``element`` and explicitly specify the layouts to use for A, B, and C (layout of D is
|
| 171 |
-
# the same as that for D, at present)
|
| 172 |
-
Gemm(element=cutlass_cppgen.DataType.f32, layout_A=cutlass_cppgen.LayoutType.RowMajor,
|
| 173 |
-
layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor)
|
| 174 |
-
|
| 175 |
-
# Explicitly specify the data type and layout for only some of A, B, C, and D. Unspecified data types
|
| 176 |
-
# and layouts will inherit those passed in via the generic ``element`` and ``layout``
|
| 177 |
-
Gemm(element_A=cutlass_cppgen.DataType.f32, layout_B=cutlass_cppgen.LayoutType.RowMajor,
|
| 178 |
-
element=cutlass_cppgen.DataType.f32, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 179 |
-
|
| 180 |
-
The order of precedence for the setting of the data type and layout for a given operand/output is as follows:
|
| 181 |
-
1) If the tensor type is specified (e.g., ``A``), use the data type and layout inferred from this tensor
|
| 182 |
-
2) Otherwise, if the data type/layout (e.g., ``element_A``, ``layout_A``) is specified, use those
|
| 183 |
-
3) Otherwise, use the generic values (e.g., ``element``, ``layout``)
|
| 184 |
-
|
| 185 |
-
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
| 186 |
-
:type cc: int
|
| 187 |
-
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
| 188 |
-
:type kernel_cc: int
|
| 189 |
-
:param A: tensor representing data type and layout of operand A
|
| 190 |
-
:param B: tensor representing data type and layout of operand B
|
| 191 |
-
:param C: tensor representing data type and layout of operand C
|
| 192 |
-
:param D: tensor representing data type and layout of operand D
|
| 193 |
-
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 194 |
-
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 195 |
-
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
| 196 |
-
:type element_accumulator: cutlass_cppgen.DataType
|
| 197 |
-
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
| 198 |
-
:type element: cutlass_cppgen.DataType
|
| 199 |
-
:param layout: generic layout type to be used for operands A, B, C, and D
|
| 200 |
-
:type layout: cutlass_cppgen.LayoutType
|
| 201 |
-
:param element_A: data type to be used for operand A
|
| 202 |
-
:type element_A: cutlass_cppgen.DataType
|
| 203 |
-
:param element_B: data type to be used for operand B
|
| 204 |
-
:type element_B: cutlass_cppgen.DataType
|
| 205 |
-
:param element_C: data type to be used for operand C
|
| 206 |
-
:type element_C: cutlass_cppgen.DataType
|
| 207 |
-
:param element_D: data type to be used for operand D
|
| 208 |
-
:type element_D: cutlass_cppgen.DataType
|
| 209 |
-
:param layout_A: layout of operand A
|
| 210 |
-
:type layout_A: cutlass_cppgen.LayoutType
|
| 211 |
-
:param layout_B: layout of operand B
|
| 212 |
-
:type layout_B: cutlass_cppgen.LayoutType
|
| 213 |
-
:param layout_C: layout of operand C
|
| 214 |
-
:type layout_C: cutlass_cppgen.LayoutType
|
| 215 |
-
:param layout_D: layout of operand D
|
| 216 |
-
:type layout_D: cutlass_cppgen.LayoutType
|
| 217 |
-
"""
|
| 218 |
-
|
| 219 |
-
def __init__(
|
| 220 |
-
self, A=None, B=None, C=None, D=None,
|
| 221 |
-
alpha=1.0, beta=0.0, element_accumulator=None,
|
| 222 |
-
element=None, layout=None,
|
| 223 |
-
element_A=None, element_B=None, element_C=None, element_D=None,
|
| 224 |
-
layout_A=None, layout_B=None, layout_C=None,
|
| 225 |
-
cc: int = None, kernel_cc: int = None
|
| 226 |
-
):
|
| 227 |
-
super().__init__(cc=cc, kernel_cc=kernel_cc)
|
| 228 |
-
self.name = "gemm"
|
| 229 |
-
self.compiled = False
|
| 230 |
-
|
| 231 |
-
elements = []
|
| 232 |
-
layouts = []
|
| 233 |
-
|
| 234 |
-
# Check that at least one of the following is set for each tensor (illustrated assuming tensor A):
|
| 235 |
-
# ``A``, ``element_A``, ``element`` and ``A``, ``layout_A``, ``layout``
|
| 236 |
-
for elt, lay, tens, name in zip([element_A, element_B, element_C, element_D],
|
| 237 |
-
[layout_A, layout_B, layout_C, layout_C],
|
| 238 |
-
[A, B, C, D],
|
| 239 |
-
["A", "B", "C", "D"]):
|
| 240 |
-
if elt is not None and tens is not None:
|
| 241 |
-
raise Exception(f'Must not specify both element_{name} and tensor {name}')
|
| 242 |
-
if lay is not None and tens is not None:
|
| 243 |
-
raise Exception(f'Must not specify both layout_{name} and tensor {name}')
|
| 244 |
-
if elt is None and tens is None and element is None:
|
| 245 |
-
raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')
|
| 246 |
-
if lay is None and tens is None and layout is None:
|
| 247 |
-
raise Exception(f'Must specify one of layout_{name}, tensor {name}, or generic layout.')
|
| 248 |
-
|
| 249 |
-
elt_to_set = None
|
| 250 |
-
lay_to_set = None
|
| 251 |
-
if tens is not None:
|
| 252 |
-
elt_to_set, lay_to_set = datatypes.get_datatype_and_layout(tens)
|
| 253 |
-
else:
|
| 254 |
-
elt_to_set = elt if elt is not None else element
|
| 255 |
-
lay_to_set = lay if lay is not None else layout
|
| 256 |
-
|
| 257 |
-
elements.append(datatypes.library_type(elt_to_set))
|
| 258 |
-
layouts.append(lay_to_set)
|
| 259 |
-
|
| 260 |
-
self._element_a, self._element_b, self._element_c, self._element_d = elements
|
| 261 |
-
self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts
|
| 262 |
-
|
| 263 |
-
if element_accumulator is None:
|
| 264 |
-
self._element_accumulator = self._element_c
|
| 265 |
-
else:
|
| 266 |
-
self._element_accumulator = datatypes.library_type(element_accumulator)
|
| 267 |
-
|
| 268 |
-
self.A = A
|
| 269 |
-
self.B = B
|
| 270 |
-
self.C = C
|
| 271 |
-
self.D = D
|
| 272 |
-
|
| 273 |
-
self.alpha = alpha
|
| 274 |
-
self.beta = beta
|
| 275 |
-
|
| 276 |
-
self.epilogue_functor = None
|
| 277 |
-
self.op_class = None
|
| 278 |
-
self._tile_description = None
|
| 279 |
-
|
| 280 |
-
self._reset_operations()
|
| 281 |
-
|
| 282 |
-
self._swizzling_functor = cutlass_cppgen.swizzle.IdentitySwizzle1
|
| 283 |
-
|
| 284 |
-
def _reset_operations(self, reset_epilogue: bool = True):
|
| 285 |
-
# Set the default op class
|
| 286 |
-
datatype_comb = (self._element_a, self._element_b, self._element_accumulator)
|
| 287 |
-
layout_comb = (self._layout_a, self._layout_b)
|
| 288 |
-
|
| 289 |
-
self.possible_op_classes = self.options.supporting_opclasses(
|
| 290 |
-
self._element_a, self._element_b, self._element_accumulator,
|
| 291 |
-
self._layout_a, self._layout_b, self._math_operation)
|
| 292 |
-
|
| 293 |
-
if cutlass_cppgen.OpcodeClass.TensorOp in self.possible_op_classes:
|
| 294 |
-
self.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
| 295 |
-
elif cutlass_cppgen.OpcodeClass.Simt in self.possible_op_classes:
|
| 296 |
-
self.opclass = cutlass_cppgen.OpcodeClass.Simt
|
| 297 |
-
else:
|
| 298 |
-
if self._math_operation is not None:
|
| 299 |
-
math_op_str = f' and math operation {self._math_operation}'
|
| 300 |
-
else:
|
| 301 |
-
math_op_str = ''
|
| 302 |
-
|
| 303 |
-
raise Exception(f'No kernel configuration found for supported data type and layout '
|
| 304 |
-
f'combination {datatype_comb}x{layout_comb}{math_op_str}')
|
| 305 |
-
|
| 306 |
-
if reset_epilogue:
|
| 307 |
-
self._reset_epilogue_functor_activation(cutlass_cppgen.epilogue.identity)
|
| 308 |
-
|
| 309 |
-
@property
|
| 310 |
-
def swizzling_functor(self):
|
| 311 |
-
"""
|
| 312 |
-
Returns the type of the swizzling functor currently being used by the GEMM
|
| 313 |
-
|
| 314 |
-
:return: swizzing functor type
|
| 315 |
-
"""
|
| 316 |
-
return self._swizzling_functor
|
| 317 |
-
|
| 318 |
-
@swizzling_functor.setter
|
| 319 |
-
def swizzling_functor(self, swizzling_functor):
|
| 320 |
-
"""
|
| 321 |
-
Sets the swizzling functor to the type specified by `swizzling_functor`
|
| 322 |
-
"""
|
| 323 |
-
if swizzling_functor == cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK:
|
| 324 |
-
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
| 325 |
-
raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp')
|
| 326 |
-
|
| 327 |
-
if self.current_cc in [90, 100, 101, 103]:
|
| 328 |
-
raise Exception('ThreadblockSwizzleStreamK is currently unsupported on SM90+')
|
| 329 |
-
self._swizzling_functor = swizzling_functor
|
| 330 |
-
|
| 331 |
-
#
|
| 332 |
-
# Tile description Related
|
| 333 |
-
#
|
| 334 |
-
|
| 335 |
-
@property
|
| 336 |
-
def tile_description(self) -> TileDescription:
|
| 337 |
-
"""
|
| 338 |
-
Returns the tile description
|
| 339 |
-
"""
|
| 340 |
-
return self._tile_description
|
| 341 |
-
|
| 342 |
-
@tile_description.setter
|
| 343 |
-
def tile_description(
|
| 344 |
-
self, td=None):
|
| 345 |
-
"""
|
| 346 |
-
Set the tile description
|
| 347 |
-
|
| 348 |
-
:param td: tile description
|
| 349 |
-
:type td: cutlass_cppgen.backend.TileDescription, or a dict with keys
|
| 350 |
-
{
|
| 351 |
-
"threadblock_shape": [int, int, int],
|
| 352 |
-
"warp_count": [int, int, int],
|
| 353 |
-
"stages": int,
|
| 354 |
-
"instruction_shape": [int, int, int] (optional),
|
| 355 |
-
"cluster_shape": [int, int, int] (optional)
|
| 356 |
-
}
|
| 357 |
-
"""
|
| 358 |
-
if td is None:
|
| 359 |
-
return
|
| 360 |
-
if isinstance(td, dict):
|
| 361 |
-
if self._tile_description is None:
|
| 362 |
-
op = self.possible_operations.default_operation(self._math_operation)
|
| 363 |
-
self._tile_description = datatypes.td_from_profiler_op(op)
|
| 364 |
-
td = self._tile_description.clone_and_update(td)
|
| 365 |
-
|
| 366 |
-
valid, msg = self._valid_tile_description(td)
|
| 367 |
-
if valid:
|
| 368 |
-
self._tile_description = td
|
| 369 |
-
else:
|
| 370 |
-
raise Exception(msg)
|
| 371 |
-
|
| 372 |
-
def _valid_tile_description(self, td: TileDescription) -> tuple:
|
| 373 |
-
"""
|
| 374 |
-
Checks whether the provided tile description is valid for the given compute capability. At present,
|
| 375 |
-
this checks the following:
|
| 376 |
-
|
| 377 |
-
- Does the tile description use a number of stages supported by the compute capability in question?
|
| 378 |
-
- Does the tile size requested fit within shared memory?
|
| 379 |
-
- Are cluster dimensions outside the valid range requested for a given architecture (e.g.,
|
| 380 |
-
more non-unit cluster dimensions for pre-SM90 architectures)?
|
| 381 |
-
- Is the kernel schedule being used supported on the architecture in question?
|
| 382 |
-
|
| 383 |
-
:param td: tile description to validate
|
| 384 |
-
:type td: cutlass_cppgen.backend.TileDescription
|
| 385 |
-
:return: tuple in which the first element is a bool indicating that the tile description is valid
|
| 386 |
-
and the second element is a string providing an optional error message.
|
| 387 |
-
:rtype: tuple
|
| 388 |
-
"""
|
| 389 |
-
valid, msg = check.valid_stage_count(self.cc, self.current_cc, td, self._element_c, self._element_d)
|
| 390 |
-
if not valid:
|
| 391 |
-
return (valid, msg)
|
| 392 |
-
|
| 393 |
-
valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape)
|
| 394 |
-
if not valid:
|
| 395 |
-
return (valid, msg)
|
| 396 |
-
|
| 397 |
-
valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler)
|
| 398 |
-
|
| 399 |
-
if self.cc in [100, 101, 103] and td.kernel_schedule is not None and td.is_2sm and td.cluster_shape[0] % 2 != 0:
|
| 400 |
-
valid = False
|
| 401 |
-
msg = "Cluster shape must be divisible by 2 for 2SM kernels on SM100, SM101, and SM103"
|
| 402 |
-
|
| 403 |
-
return valid, msg
|
| 404 |
-
|
| 405 |
-
def tile_descriptions(self) -> list:
|
| 406 |
-
"""
|
| 407 |
-
Returns a list of valid tile descriptions for the operations
|
| 408 |
-
|
| 409 |
-
:returns: list of valid tile descriptions for the operations
|
| 410 |
-
:rtype: list
|
| 411 |
-
"""
|
| 412 |
-
tds = [datatypes.td_from_profiler_op(op) for op in self.possible_operations.all_operations]
|
| 413 |
-
if self._math_operation is not None:
|
| 414 |
-
tds = [td for td in tds if td.math_instruction.math_operation == self._math_operation]
|
| 415 |
-
return tds
|
| 416 |
-
|
| 417 |
-
def construct(
|
| 418 |
-
self, tile_description: TileDescription = None,
|
| 419 |
-
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None) -> GemmOperationUniversal:
|
| 420 |
-
"""
|
| 421 |
-
Constructs a ``cutlass_cppgen.backend.GemmUniversalOperation`` based on the input parameters and current
|
| 422 |
-
kernel specification of the ``Gemm`` object.
|
| 423 |
-
|
| 424 |
-
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 425 |
-
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 426 |
-
:param alignment_A: alignment of operand A
|
| 427 |
-
:type alignment_A: int
|
| 428 |
-
:param alignment_B: alignment of operand B
|
| 429 |
-
:type alignment_B: int
|
| 430 |
-
:param alignment_C: alignment of operand C
|
| 431 |
-
:type alignment_C: int
|
| 432 |
-
|
| 433 |
-
:return: operation that was constructed
|
| 434 |
-
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
|
| 435 |
-
"""
|
| 436 |
-
alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A")))
|
| 437 |
-
alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B")))
|
| 438 |
-
alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A)
|
| 439 |
-
alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B)
|
| 440 |
-
|
| 441 |
-
tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
|
| 442 |
-
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
| 443 |
-
|
| 444 |
-
if alignment_C is None:
|
| 445 |
-
alignment_C = max(self.possible_operations.alignments("C"))
|
| 446 |
-
if self._element_c != DataType.void:
|
| 447 |
-
alignment_C = min(128 // DataTypeSize[self._element_c], alignment_C)
|
| 448 |
-
|
| 449 |
-
if tile_description is None:
|
| 450 |
-
if self._tile_description is None:
|
| 451 |
-
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
| 452 |
-
tile_description = datatypes.td_from_profiler_op(op)
|
| 453 |
-
|
| 454 |
-
# The selected op may have lower alignment than that determined above, so we must
|
| 455 |
-
# reset alignment here.
|
| 456 |
-
alignment_C = op.C.alignment
|
| 457 |
-
else:
|
| 458 |
-
tile_description = self._tile_description
|
| 459 |
-
else:
|
| 460 |
-
valid, err_str = self._valid_tile_description(tile_description)
|
| 461 |
-
if not valid:
|
| 462 |
-
raise Exception(f"Invalid tile description. {err_str}")
|
| 463 |
-
self._tile_description = tile_description
|
| 464 |
-
|
| 465 |
-
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
| 466 |
-
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
| 467 |
-
|
| 468 |
-
operation = GemmOperationUniversal(
|
| 469 |
-
arch=self.current_cc,
|
| 470 |
-
tile_description=tile_description,
|
| 471 |
-
A=tensor_A, B=tensor_B, C=tensor_C,
|
| 472 |
-
epilogue_functor=self.epilogue_functor,
|
| 473 |
-
swizzling_functor=self._swizzling_functor,
|
| 474 |
-
)
|
| 475 |
-
|
| 476 |
-
return operation
|
| 477 |
-
|
| 478 |
-
def compile(self, tile_description: TileDescription = None,
|
| 479 |
-
alignment_A: int = None, alignment_B: int = None, alignment_C: int = None,
|
| 480 |
-
print_module: bool = False) -> cutlass_cppgen.backend.GemmOperationUniversal:
|
| 481 |
-
"""
|
| 482 |
-
Emits and compiles the kernel currently specified. If ``tile_description`` and any
|
| 483 |
-
of the ``alignment`` parameters are set, the kernel will be chosen using this
|
| 484 |
-
tile description and alignments. Otherwise, a default tile description and alignment
|
| 485 |
-
will be used.
|
| 486 |
-
|
| 487 |
-
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 488 |
-
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 489 |
-
:param alignment_A: alignment of operand A
|
| 490 |
-
:type alignment_A: int
|
| 491 |
-
:param alignment_B: alignment of operand B
|
| 492 |
-
:type alignment_B: int
|
| 493 |
-
:param alignment_C: alignment of operand C
|
| 494 |
-
:type alignment_C: int
|
| 495 |
-
:param print_module: whether to print the emitted C++ code
|
| 496 |
-
:type print_module: bool
|
| 497 |
-
|
| 498 |
-
:return: operation that was compiled
|
| 499 |
-
:rtype: cutlass_cppgen.backend.GemmOperationUniversal
|
| 500 |
-
"""
|
| 501 |
-
self.operation = self.construct(tile_description, alignment_A, alignment_B, alignment_C)
|
| 502 |
-
|
| 503 |
-
if print_module:
|
| 504 |
-
print(self.operation.rt_module.emit())
|
| 505 |
-
|
| 506 |
-
compiler.add_module([self.operation,])
|
| 507 |
-
return self.operation
|
| 508 |
-
|
| 509 |
-
def _verify_rank(self, tensor):
|
| 510 |
-
"""
|
| 511 |
-
Verifies that ``tensor`` has rank greater than 1
|
| 512 |
-
|
| 513 |
-
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 514 |
-
:type tensor: numpy/cupy/torch array/tensor object
|
| 515 |
-
"""
|
| 516 |
-
if len(tensor.shape) < 2:
|
| 517 |
-
raise Exception(f"Tensors must be of rank greater than 1. Received tensor of shape: {tensor.shape}")
|
| 518 |
-
|
| 519 |
-
def _get_batch_count(self, A, B, C, D) -> int:
|
| 520 |
-
"""
|
| 521 |
-
Returns the batch count specified by the tensors A, B, C, and D and verifies that these
|
| 522 |
-
tensors match in batch size. Presence of a batch dimension is detected by one of the
|
| 523 |
-
tensors being rank 3. If a batch dimension is present, it must be present in one of
|
| 524 |
-
operands A, B, or C (but need not be in all), and must be present in D.
|
| 525 |
-
|
| 526 |
-
:param A: tensor A
|
| 527 |
-
:type A: numpy/cupy/torch array/tensor object
|
| 528 |
-
:param B: tensor B
|
| 529 |
-
:type B: numpy/cupy/torch array/tensor object
|
| 530 |
-
:param C: tensor C
|
| 531 |
-
:type C: numpy/cupy/torch array/tensor object
|
| 532 |
-
:param D: tensor D
|
| 533 |
-
:type D: numpy/cupy/torch array/tensor object
|
| 534 |
-
|
| 535 |
-
:return: tuple of batch count dimensions
|
| 536 |
-
:rtype: tuple
|
| 537 |
-
"""
|
| 538 |
-
A_batch = prod(A.shape[:-2]) if len(A.shape) > 2 else 1
|
| 539 |
-
B_batch = prod(B.shape[:-2]) if len(B.shape) > 2 else 1
|
| 540 |
-
|
| 541 |
-
if 1 not in [A_batch, B_batch]:
|
| 542 |
-
if A_batch != B_batch:
|
| 543 |
-
raise Exception(f"Get invalid batch counts: A={A_batch}, B={B_batch}")
|
| 544 |
-
return max(A_batch, B_batch)
|
| 545 |
-
|
| 546 |
-
def _get_batch_stride(self, tensor) -> int:
|
| 547 |
-
"""
|
| 548 |
-
Returns the batch stride of ``tensor``. If ``tensor`` is only rank-2, batch stride is 0.
|
| 549 |
-
|
| 550 |
-
:param tensor: tensor object to process
|
| 551 |
-
:type tensor: numpy/cupy/torch array/tensor object
|
| 552 |
-
|
| 553 |
-
:return: stride between each matrix in the batch
|
| 554 |
-
:rtype: int
|
| 555 |
-
"""
|
| 556 |
-
if tensor is not None and len(tensor.shape) > 2:
|
| 557 |
-
return tensor.shape[-2] * tensor.shape[-1]
|
| 558 |
-
else:
|
| 559 |
-
return 0
|
| 560 |
-
|
| 561 |
-
def _get_problem_args(self, A, B, C, D) -> tuple:
|
| 562 |
-
"""
|
| 563 |
-
Returns the problem size and GEMM universal mode to use for the
|
| 564 |
-
given operands.
|
| 565 |
-
|
| 566 |
-
:param A: tensor A
|
| 567 |
-
:type A: numpy/cupy/torch array/tensor object
|
| 568 |
-
:param B: tensor B
|
| 569 |
-
:type B: numpy/cupy/torch array/tensor object
|
| 570 |
-
:param C: tensor C
|
| 571 |
-
:type C: numpy/cupy/torch array/tensor object
|
| 572 |
-
:param D: tensor D
|
| 573 |
-
:type D: numpy/cupy/torch array/tensor object
|
| 574 |
-
|
| 575 |
-
:return: tuple containing the problem size (cutlass_cppgen.shape.GemmCoord), the GEMM mode (cutlass_cppgen.GemmUniversalMode), and the batch count (int)
|
| 576 |
-
:rtype: tuple
|
| 577 |
-
"""
|
| 578 |
-
M, K = A.shape[-2:]
|
| 579 |
-
N = B.shape[-1]
|
| 580 |
-
mode = GemmUniversalMode.Gemm
|
| 581 |
-
|
| 582 |
-
batch_count = self._get_batch_count(A, B, C, D)
|
| 583 |
-
returned_batch_count = batch_count
|
| 584 |
-
|
| 585 |
-
# If we are running a batched GEMM in which there is a nonzero batch stride
|
| 586 |
-
# only for A, then we can fold the batched dimension of A into the M dimension
|
| 587 |
-
# (i.e., (b, m, k) x (k, n) -> (m*b, k) x (k, n)). This works only if both A
|
| 588 |
-
# and C are row major. A similar operation can be performed if only B has a nonzero
|
| 589 |
-
# batch dimension
|
| 590 |
-
if batch_count > 1:
|
| 591 |
-
A_row = self._layout_a == cutlass_cppgen.LayoutType.RowMajor
|
| 592 |
-
B_row = self._layout_b == cutlass_cppgen.LayoutType.RowMajor
|
| 593 |
-
C_row = self._layout_c == cutlass_cppgen.LayoutType.RowMajor
|
| 594 |
-
|
| 595 |
-
# Consider a Tensor to be batched if its rank is > 2 and
|
| 596 |
-
# the product of the modes beyond rank 2 equals our pre-determined batch size.
|
| 597 |
-
batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count)
|
| 598 |
-
|
| 599 |
-
if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row:
|
| 600 |
-
M *= batch_count
|
| 601 |
-
returned_batch_count = 1
|
| 602 |
-
elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row:
|
| 603 |
-
N *= batch_count
|
| 604 |
-
returned_batch_count = 1
|
| 605 |
-
else:
|
| 606 |
-
mode = GemmUniversalMode.Batched
|
| 607 |
-
|
| 608 |
-
return GemmCoord(M, N, K), mode, returned_batch_count
|
| 609 |
-
|
| 610 |
-
def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name):
|
| 611 |
-
"""
|
| 612 |
-
Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception
|
| 613 |
-
is raised if it does not.
|
| 614 |
-
|
| 615 |
-
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 616 |
-
:type tensor: numpy/cupy/torch array/tensor object
|
| 617 |
-
:param ref_dtype: data type for the tensor that this object was initialized to
|
| 618 |
-
:param ref_layout: layout for the tensor that this object was initialized to
|
| 619 |
-
:param name: identifier of the tensor to verify. Used in raising exceptions
|
| 620 |
-
:type name: str
|
| 621 |
-
"""
|
| 622 |
-
dtype, layout = datatypes.get_datatype_and_layout(tensor)
|
| 623 |
-
if dtype != ref_type or layout != ref_layout:
|
| 624 |
-
try:
|
| 625 |
-
# Attempt to transpose the tensor to fit the desired layout
|
| 626 |
-
tensor = tensor.transpose(-1, -2)
|
| 627 |
-
except:
|
| 628 |
-
raise Exception(f'Tensor {name} with type and layout ({dtype}, {layout}) '
|
| 629 |
-
f'does not match the expected type and '
|
| 630 |
-
f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
|
| 631 |
-
|
| 632 |
-
def run(self, A=None, B=None, C=None, D=None,
|
| 633 |
-
alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
|
| 634 |
-
stream: Optional[cuda.CUstream] = None) -> GemmArguments:
|
| 635 |
-
"""
|
| 636 |
-
Runs the kernel currently specified. If it has not already been, the kernel is emitted and
|
| 637 |
-
compiled. Tensors holding operands and outputs of the kernel are sourced either from the
|
| 638 |
-
``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta``
|
| 639 |
-
parameters provided in this call, or from those
|
| 640 |
-
passed in on the construction of this object -- one of the two must be specified.
|
| 641 |
-
|
| 642 |
-
By default, this call returns only once the kernel has completed. To launch the kernel
|
| 643 |
-
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
|
| 644 |
-
caller to syncrhonize the results of the kernel before attempting to access outputs
|
| 645 |
-
by calling ``sync()`` on the arguments returned from this call.
|
| 646 |
-
|
| 647 |
-
:param A: tensor representing data type and layout of operand A
|
| 648 |
-
:param B: tensor representing data type and layout of operand B
|
| 649 |
-
:param C: tensor representing data type and layout of operand C
|
| 650 |
-
:param D: tensor representing data type and layout of operand D
|
| 651 |
-
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 652 |
-
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 653 |
-
:param sync: whether the call should wait for the kernel to complete before returning
|
| 654 |
-
:type sync: bool
|
| 655 |
-
:param print_module: whether to print the emitted C++ code
|
| 656 |
-
:type print_module: bool
|
| 657 |
-
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 658 |
-
:type stream: :class:`cuda.cuda.CUstream`
|
| 659 |
-
|
| 660 |
-
:return: arguments passed in to the kernel
|
| 661 |
-
:rtype: cutlass_cppgen.backend.GemmArguments
|
| 662 |
-
"""
|
| 663 |
-
if not stream:
|
| 664 |
-
stream = cuda.CUstream(0)
|
| 665 |
-
super().run_setup()
|
| 666 |
-
A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
|
| 667 |
-
B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
|
| 668 |
-
C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
|
| 669 |
-
D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
|
| 670 |
-
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
| 671 |
-
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
| 672 |
-
|
| 673 |
-
is_void_c = self._element_c == DataType.void
|
| 674 |
-
|
| 675 |
-
self._verify_rank(A)
|
| 676 |
-
self._verify_rank(B)
|
| 677 |
-
if not is_void_c:
|
| 678 |
-
self._verify_rank(C)
|
| 679 |
-
self._verify_rank(D)
|
| 680 |
-
|
| 681 |
-
alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A")
|
| 682 |
-
alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B")
|
| 683 |
-
|
| 684 |
-
# Set C alignment based on D.shape so as to correctly get an alignment with void-C
|
| 685 |
-
# kernels, for which `C` is None.
|
| 686 |
-
alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C")
|
| 687 |
-
self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
| 688 |
-
alignment_C=alignment_c, print_module=print_module)
|
| 689 |
-
|
| 690 |
-
problem_size, mode, batch_count = self._get_problem_args(A, B, C, D)
|
| 691 |
-
|
| 692 |
-
if mode == GemmUniversalMode.Gemm or batch_count == 1:
|
| 693 |
-
kwargs = {'split_k_slices': 1}
|
| 694 |
-
else:
|
| 695 |
-
kwargs = {
|
| 696 |
-
'batch': batch_count,
|
| 697 |
-
'batch_strides': {
|
| 698 |
-
'A': self._get_batch_stride(A),
|
| 699 |
-
'B': self._get_batch_stride(B),
|
| 700 |
-
'C': self._get_batch_stride(C),
|
| 701 |
-
'D': self._get_batch_stride(D)
|
| 702 |
-
}
|
| 703 |
-
}
|
| 704 |
-
|
| 705 |
-
kwargs['stream'] = stream
|
| 706 |
-
|
| 707 |
-
if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
|
| 708 |
-
output_op = self.operation.epilogue_type(visitor_args)
|
| 709 |
-
else:
|
| 710 |
-
output_op = self.operation.epilogue_type(alpha, beta)
|
| 711 |
-
|
| 712 |
-
arguments = GemmArguments(
|
| 713 |
-
operation=self.operation, problem_size=problem_size,
|
| 714 |
-
A=A, B=B, C=C, D=D,
|
| 715 |
-
output_op=output_op,
|
| 716 |
-
gemm_mode=mode,
|
| 717 |
-
**kwargs
|
| 718 |
-
)
|
| 719 |
-
|
| 720 |
-
self.operation.run(arguments)
|
| 721 |
-
|
| 722 |
-
if sync:
|
| 723 |
-
arguments.sync()
|
| 724 |
-
|
| 725 |
-
return arguments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/gemm_grouped.py
DELETED
|
@@ -1,269 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Ease-of-use interface for constructing, compiling, and running GEMMs.
|
| 35 |
-
|
| 36 |
-
The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
|
| 37 |
-
grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
|
| 38 |
-
Under the hood, the interface will select sensible default parameters for the many template
|
| 39 |
-
parameters for CUTLASS grouped GEMMs.
|
| 40 |
-
|
| 41 |
-
Note: optimal performance is not to be expected from this interface. To achieve optimal
|
| 42 |
-
performance, one should specify and tune each configuration parameter.
|
| 43 |
-
|
| 44 |
-
The simplest example of using this interface is the following:
|
| 45 |
-
|
| 46 |
-
.. highlight:: python
|
| 47 |
-
.. code-block:: python
|
| 48 |
-
|
| 49 |
-
# As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
|
| 50 |
-
plan = cutlass_cppgen.op.GroupedGemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 51 |
-
plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
|
| 52 |
-
"""
|
| 53 |
-
from __future__ import annotations
|
| 54 |
-
from typing import Optional
|
| 55 |
-
from cutlass_library import DataTypeSize
|
| 56 |
-
|
| 57 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 58 |
-
cuda = lazy_import("cuda.cuda")
|
| 59 |
-
from cutlass_cppgen.backend.gemm_operation import (
|
| 60 |
-
GemmGroupedArguments,
|
| 61 |
-
GemmOperationGrouped,
|
| 62 |
-
)
|
| 63 |
-
from cutlass_cppgen.backend.library import (
|
| 64 |
-
SchedulerMode,
|
| 65 |
-
TensorDescription,
|
| 66 |
-
TileDescription,
|
| 67 |
-
)
|
| 68 |
-
from cutlass_cppgen.op.gemm import Gemm
|
| 69 |
-
from cutlass_cppgen.shape import GemmCoord
|
| 70 |
-
from cutlass_cppgen.utils import check, datatypes
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
class GroupedGemm(Gemm):
|
| 74 |
-
"""
|
| 75 |
-
Constructs a ``GroupedGemm`` object.
|
| 76 |
-
|
| 77 |
-
The data types and layouts of operands A, B, and C, along with the data type of output D
|
| 78 |
-
and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
|
| 79 |
-
these are not to be changed after a ``GroupedGemm`` has been constructed.
|
| 80 |
-
|
| 81 |
-
The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
|
| 82 |
-
for ``Gemm`` for examples of these.
|
| 83 |
-
|
| 84 |
-
:param cc: compute capability of device to generate kernels for
|
| 85 |
-
:type cc: int
|
| 86 |
-
:param A: tensor representing data type and layout of operands A
|
| 87 |
-
:param B: tensor representing data type and layout of operands B
|
| 88 |
-
:param C: tensor representing data type and layout of operands C
|
| 89 |
-
:param D: tensor representing data type and layout of operands D
|
| 90 |
-
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 91 |
-
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 92 |
-
:param element_accumulator: data type to be used in accumulation of the product of operands A and B
|
| 93 |
-
:type element_accumulator: cutlass_cppgen.DataType
|
| 94 |
-
:param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
|
| 95 |
-
:type element: cutlass_cppgen.DataType
|
| 96 |
-
:param layout: generic layout type to be used for operands A, B, C, and D
|
| 97 |
-
:type layout: cutlass_cppgen.LayoutType
|
| 98 |
-
:param element_A: data type to be used for operand A
|
| 99 |
-
:type element_A: cutlass_cppgen.DataType
|
| 100 |
-
:param element_B: data type to be used for operand B
|
| 101 |
-
:type element_B: cutlass_cppgen.DataType
|
| 102 |
-
:param element_C: data type to be used for operand C
|
| 103 |
-
:type element_C: cutlass_cppgen.DataType
|
| 104 |
-
:param element_D: data type to be used for operand D
|
| 105 |
-
:type element_D: cutlass_cppgen.DataType
|
| 106 |
-
:type layout_A: layout of operand A
|
| 107 |
-
:param layout_A: cutlass_cppgen.LayoutType
|
| 108 |
-
:type layout_B: layout of operand B
|
| 109 |
-
:param layout_B: cutlass_cppgen.LayoutType
|
| 110 |
-
:type layout_C: layout of operand C
|
| 111 |
-
:param layout_C: cutlass_cppgen.LayoutType
|
| 112 |
-
:type layout_D: layout of operand D
|
| 113 |
-
:param layout_D: cutlass_cppgen.LayoutType
|
| 114 |
-
"""
|
| 115 |
-
|
| 116 |
-
def __init__(
|
| 117 |
-
self, A=None, B=None, C=None, D=None,
|
| 118 |
-
alpha=1.0, beta=0.0, element_accumulator=None,
|
| 119 |
-
element=None, layout=None,
|
| 120 |
-
element_A=None, element_B=None, element_C=None, element_D=None,
|
| 121 |
-
layout_A=None, layout_B=None, layout_C=None,
|
| 122 |
-
cc: int = None,
|
| 123 |
-
):
|
| 124 |
-
super().__init__(
|
| 125 |
-
A=A, B=B, C=C, D=D,
|
| 126 |
-
alpha=alpha, beta=beta,
|
| 127 |
-
element_accumulator=element_accumulator,
|
| 128 |
-
element=element, layout=layout,
|
| 129 |
-
element_A=element_A, element_B=element_B,
|
| 130 |
-
element_C=element_C, element_D=element_D,
|
| 131 |
-
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
|
| 132 |
-
cc=cc
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
# Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
|
| 136 |
-
if self.current_cc in [90, 100, 101, 103]:
|
| 137 |
-
self._reset_options(80)
|
| 138 |
-
self._reset_operations(reset_epilogue=False)
|
| 139 |
-
|
| 140 |
-
self.name = "grouped_gemm"
|
| 141 |
-
|
| 142 |
-
@Gemm.swizzling_functor.setter
|
| 143 |
-
def swizzling_functor(self, swizzling_functor):
|
| 144 |
-
"""
|
| 145 |
-
Sets the swizzling functor to the type specified by `swizzling_functor`
|
| 146 |
-
"""
|
| 147 |
-
raise Exception('Grouped GEMM does not currently support different swizzling functors')
|
| 148 |
-
|
| 149 |
-
def construct(self, tile_description: TileDescription = None,
|
| 150 |
-
alignment_A: int = None,
|
| 151 |
-
alignment_B: int = None,
|
| 152 |
-
alignment_C: int = None) -> GemmOperationGrouped:
|
| 153 |
-
"""
|
| 154 |
-
Constructs a ``cutlass_cppgen.backend.GemmOperationGrouped`` based on the input parameters and current
|
| 155 |
-
kernel specification of the ``Gemm`` object.
|
| 156 |
-
|
| 157 |
-
:param tile_description: tile description specifying shapes and operand types to use in the kernel
|
| 158 |
-
:type tile_description: cutlass_cppgen.backend.TileDescription
|
| 159 |
-
:param alignment_A: alignment of operand A
|
| 160 |
-
:type alignment_A: int
|
| 161 |
-
:param alignment_B: alignment of operand B
|
| 162 |
-
:type alignment_B: int
|
| 163 |
-
:param alignment_C: alignment of operand C
|
| 164 |
-
:type alignment_C: int
|
| 165 |
-
|
| 166 |
-
:return: operation that was constructed
|
| 167 |
-
:rtype: cutlass_cppgen.backend.GemmOperationGrouped
|
| 168 |
-
"""
|
| 169 |
-
alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
|
| 170 |
-
alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
|
| 171 |
-
alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))
|
| 172 |
-
|
| 173 |
-
self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)
|
| 174 |
-
|
| 175 |
-
tensor_A = TensorDescription(self._element_a, self._layout_b, alignment_A)
|
| 176 |
-
tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
|
| 177 |
-
tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)
|
| 178 |
-
|
| 179 |
-
if tile_description is None:
|
| 180 |
-
op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
|
| 181 |
-
tile_description = datatypes.td_from_profiler_op(op)
|
| 182 |
-
else:
|
| 183 |
-
valid, err_str = self._valid_tile_description(tile_description)
|
| 184 |
-
if not valid:
|
| 185 |
-
raise Exception(f"Invalid tile description. {err_str}")
|
| 186 |
-
self.tile_description = tile_description
|
| 187 |
-
|
| 188 |
-
operation = GemmOperationGrouped(
|
| 189 |
-
arch=self.current_cc,
|
| 190 |
-
tile_description=tile_description,
|
| 191 |
-
A=tensor_A, B=tensor_B, C=tensor_C,
|
| 192 |
-
epilogue_functor=self.epilogue_functor,
|
| 193 |
-
swizzling_functor=self._swizzling_functor,
|
| 194 |
-
precompute_mode=SchedulerMode.Device)
|
| 195 |
-
|
| 196 |
-
return operation
|
| 197 |
-
|
| 198 |
-
def run(self, A, B, C, D,
|
| 199 |
-
alpha=None, beta=None, sync: bool = True,
|
| 200 |
-
print_module: bool = False,
|
| 201 |
-
stream: Optional[cuda.CUstream] = None) -> GemmGroupedArguments:
|
| 202 |
-
"""
|
| 203 |
-
Runs the kernel currently specified.
|
| 204 |
-
|
| 205 |
-
By default, this call returns only once the kernel has completed. To launch the kernel
|
| 206 |
-
and immediately return, set ``sync=False``. In this case, it is the responsibility of the
|
| 207 |
-
caller to syncrhonize the results of the kernel before attempting to access outputs
|
| 208 |
-
by calling ``sync()`` on the arguments returned from this call.
|
| 209 |
-
|
| 210 |
-
:param A: list of tensors representing data type and layout of operand A
|
| 211 |
-
:type A: list
|
| 212 |
-
:param B: list of tensors representing data type and layout of operand B
|
| 213 |
-
:type B: list
|
| 214 |
-
:param C: list of tensors representing data type and layout of operand C
|
| 215 |
-
:type C: list
|
| 216 |
-
:param D: list of tensors representing data type and layout of operand D
|
| 217 |
-
:type D: list
|
| 218 |
-
:param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B
|
| 219 |
-
:param beta: scalar parameter beta from GEMM operation that scales operand C
|
| 220 |
-
:param sync: whether the call should wait for the kernel to complete before returning
|
| 221 |
-
:type sync: bool
|
| 222 |
-
:param print_module: whether to print the emitted C++ code
|
| 223 |
-
:type print_module: bool
|
| 224 |
-
:param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
|
| 225 |
-
:type stream: :class:`cuda.cuda.CUstream`
|
| 226 |
-
|
| 227 |
-
:return: arguments passed in to the kernel
|
| 228 |
-
:rtype: cutlass_cppgen.backend.GemmGroupedArguments
|
| 229 |
-
"""
|
| 230 |
-
if not stream:
|
| 231 |
-
stream = cuda.CUstream(0)
|
| 232 |
-
|
| 233 |
-
super().run_setup()
|
| 234 |
-
|
| 235 |
-
if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
|
| 236 |
-
raise Exception("Lengths of A, B, C, and D lists must be equal")
|
| 237 |
-
|
| 238 |
-
problem_sizes = []
|
| 239 |
-
As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
|
| 240 |
-
for i in range(len(A)):
|
| 241 |
-
As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
|
| 242 |
-
Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
|
| 243 |
-
Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
|
| 244 |
-
Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
|
| 245 |
-
problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))
|
| 246 |
-
|
| 247 |
-
alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
|
| 248 |
-
beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")
|
| 249 |
-
|
| 250 |
-
alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
|
| 251 |
-
alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
|
| 252 |
-
alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
|
| 253 |
-
self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
|
| 254 |
-
alignment_C=alignment_c, print_module=print_module)
|
| 255 |
-
|
| 256 |
-
arguments = GemmGroupedArguments(
|
| 257 |
-
operation=self.operation,
|
| 258 |
-
problem_sizes=problem_sizes,
|
| 259 |
-
A=As, B=Bs, C=Cs, D=Ds,
|
| 260 |
-
output_op=self.operation.epilogue_type(alpha, beta),
|
| 261 |
-
stream=stream
|
| 262 |
-
)
|
| 263 |
-
|
| 264 |
-
self.operation.run(arguments)
|
| 265 |
-
|
| 266 |
-
if sync:
|
| 267 |
-
arguments.sync()
|
| 268 |
-
|
| 269 |
-
return arguments
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/op/op.py
DELETED
|
@@ -1,431 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from bisect import bisect_left
|
| 38 |
-
|
| 39 |
-
from cutlass_library import (
|
| 40 |
-
DataType,
|
| 41 |
-
DataTypeSize,
|
| 42 |
-
MathOperation,
|
| 43 |
-
OperationKind,
|
| 44 |
-
SharedMemPerCC
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
import cutlass_cppgen
|
| 48 |
-
from cutlass_cppgen import get_option_registry
|
| 49 |
-
from cutlass_cppgen.backend.evt import EpilogueFunctorVisitor
|
| 50 |
-
from cutlass_cppgen.backend.evt.passes.util import cc_map
|
| 51 |
-
from cutlass_cppgen.backend.utils.device import device_cc
|
| 52 |
-
from cutlass_cppgen.epilogue import get_activations, get_activation_epilogue, identity
|
| 53 |
-
from cutlass_cppgen.library_defaults import KernelsForDataType, _generator_ccs
|
| 54 |
-
from cutlass_cppgen.swizzle import get_swizzling_functors
|
| 55 |
-
from cutlass_cppgen.utils import datatypes, check
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class OperationBase:
|
| 59 |
-
"""
|
| 60 |
-
Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d)
|
| 61 |
-
"""
|
| 62 |
-
|
| 63 |
-
def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm):
|
| 64 |
-
"""
|
| 65 |
-
:param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90
|
| 66 |
-
:type cc: int
|
| 67 |
-
:param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80
|
| 68 |
-
:type kernel_cc: int
|
| 69 |
-
:param operation_kind: class of operation that will be performed (e.g., GEMM, Conv)
|
| 70 |
-
:type operation_kind: cutlass_library.OperationKind
|
| 71 |
-
"""
|
| 72 |
-
self.operation_kind = operation_kind
|
| 73 |
-
self.cc = cc if cc is not None else device_cc()
|
| 74 |
-
self.specified_kernel_cc = kernel_cc is not None
|
| 75 |
-
self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc)
|
| 76 |
-
self.tile_description = None
|
| 77 |
-
self._math_operation = None
|
| 78 |
-
|
| 79 |
-
self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind)
|
| 80 |
-
|
| 81 |
-
if self.options is None:
|
| 82 |
-
raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}")
|
| 83 |
-
|
| 84 |
-
# Default activation function: identity
|
| 85 |
-
self._activation = identity
|
| 86 |
-
|
| 87 |
-
def _find_closest_cc(self, cc: int) -> int:
|
| 88 |
-
"""
|
| 89 |
-
Returns the closest CC in _generator_ccs less than or equal to `cc`
|
| 90 |
-
|
| 91 |
-
:param cc: compute capability to query
|
| 92 |
-
:type cc: int
|
| 93 |
-
|
| 94 |
-
:returns: closest CC in _generator_ccs less than or equal to `cc`
|
| 95 |
-
:rtype: int
|
| 96 |
-
"""
|
| 97 |
-
if cc in _generator_ccs:
|
| 98 |
-
return cc
|
| 99 |
-
|
| 100 |
-
# Find closest CC lower than this CC
|
| 101 |
-
idx = bisect_left(_generator_ccs, cc)
|
| 102 |
-
if idx == 0:
|
| 103 |
-
raise Exception(f'No valid CC to fall back to for {cc}')
|
| 104 |
-
return _generator_ccs[idx-1]
|
| 105 |
-
|
| 106 |
-
def activations(self) -> list:
|
| 107 |
-
"""
|
| 108 |
-
Returns possible activation functions that can be used
|
| 109 |
-
|
| 110 |
-
:return: list of activation functions that can be used
|
| 111 |
-
:rtype: list
|
| 112 |
-
"""
|
| 113 |
-
return get_activations()
|
| 114 |
-
|
| 115 |
-
def swizzling_functors(self) -> list:
|
| 116 |
-
"""
|
| 117 |
-
Returns possible swizzling functions that can be used
|
| 118 |
-
|
| 119 |
-
:return: list of swizzling functions that can be used
|
| 120 |
-
:rtype: list
|
| 121 |
-
"""
|
| 122 |
-
return get_swizzling_functors()
|
| 123 |
-
|
| 124 |
-
def _reset_options(self, cc: int):
|
| 125 |
-
"""
|
| 126 |
-
Resets the kernel options based on cc
|
| 127 |
-
|
| 128 |
-
:param cc: compute capability to reset to
|
| 129 |
-
:type cc: int
|
| 130 |
-
"""
|
| 131 |
-
if cc != self.current_cc:
|
| 132 |
-
if cc not in _generator_ccs:
|
| 133 |
-
raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.')
|
| 134 |
-
self.current_cc = cc
|
| 135 |
-
self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind)
|
| 136 |
-
|
| 137 |
-
def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name):
|
| 138 |
-
"""
|
| 139 |
-
Verifies the following properties:
|
| 140 |
-
1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``)
|
| 141 |
-
2) If ``scalar`` is not ``None``, its datatype must match matches the current version
|
| 142 |
-
set by the plan (i.e., those in ``ref_dtype``)
|
| 143 |
-
|
| 144 |
-
If either of these properties does not hold, an exception is raised. If these properties hold and
|
| 145 |
-
``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned.
|
| 146 |
-
|
| 147 |
-
:param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 148 |
-
:type scalar: numpy/cupy/torch scalar
|
| 149 |
-
:param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
| 150 |
-
:type ref_scalar: numpy/cupy/torch scalar
|
| 151 |
-
:param ref_dtype: data type for the scalar that this object was initialized to
|
| 152 |
-
:param name: identifier of the scalar to verify. Used in raising exceptions
|
| 153 |
-
:type name: str
|
| 154 |
-
|
| 155 |
-
:return: valid scalar to use
|
| 156 |
-
:rtype: numpy/cupy/torch scalar
|
| 157 |
-
"""
|
| 158 |
-
if scalar is None:
|
| 159 |
-
if ref_scalar is None:
|
| 160 |
-
raise Exception(f"Scalar {name} must be set.")
|
| 161 |
-
return ref_scalar
|
| 162 |
-
if hasattr(scalar, "dtype"):
|
| 163 |
-
dtype = datatypes.library_type(scalar.dtype)
|
| 164 |
-
if dtype != ref_dtype:
|
| 165 |
-
raise Exception(
|
| 166 |
-
f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}."
|
| 167 |
-
)
|
| 168 |
-
return scalar
|
| 169 |
-
|
| 170 |
-
def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name):
|
| 171 |
-
"""
|
| 172 |
-
Verifies the following properties:
|
| 173 |
-
If ref_dtype is not void:
|
| 174 |
-
1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``)
|
| 175 |
-
2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions
|
| 176 |
-
set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``)
|
| 177 |
-
If ref_dtype is void:
|
| 178 |
-
Neither ``tensor`` nor ``ref_tensor`` are set
|
| 179 |
-
|
| 180 |
-
If either of these properties does not hold, an exception is raised. If these properties hold and
|
| 181 |
-
``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned.
|
| 182 |
-
|
| 183 |
-
:param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in
|
| 184 |
-
:type tensor: numpy/cupy/torch array/tensor object
|
| 185 |
-
:param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in
|
| 186 |
-
:type ref_tensor: numpy/cupy/torch array/tensor object
|
| 187 |
-
:param ref_dtype: data type for the tensor that this object was initialized to
|
| 188 |
-
:param ref_layout: layout for the tensor that this object was initialized to
|
| 189 |
-
:param name: identifier of the tensor to verify. Used in raising exceptions
|
| 190 |
-
:type name: str
|
| 191 |
-
|
| 192 |
-
:return: valid tensor object to use
|
| 193 |
-
:rtype: numpy/cupy/torch array/tensor object
|
| 194 |
-
"""
|
| 195 |
-
if ref_dtype == DataType.void:
|
| 196 |
-
if tensor is not None or ref_tensor is not None:
|
| 197 |
-
raise Exception("Operands with element DataType.void must not be provided a tensor")
|
| 198 |
-
return None
|
| 199 |
-
|
| 200 |
-
if tensor is None:
|
| 201 |
-
if ref_tensor is None:
|
| 202 |
-
raise Exception(f"Tensor {name} must be set.")
|
| 203 |
-
return ref_tensor
|
| 204 |
-
|
| 205 |
-
self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
|
| 206 |
-
return tensor
|
| 207 |
-
|
| 208 |
-
@property
|
| 209 |
-
def opclass(self) -> cutlass_cppgen.OpcodeClass:
|
| 210 |
-
"""
|
| 211 |
-
Returns the opcode class currently in use
|
| 212 |
-
|
| 213 |
-
:return: opcode class currently in use
|
| 214 |
-
:rtype: cutlass_cppgen.OpcodeClass
|
| 215 |
-
"""
|
| 216 |
-
return self.op_class
|
| 217 |
-
|
| 218 |
-
@opclass.setter
|
| 219 |
-
def opclass(self, oc: cutlass_cppgen.OpcodeClass):
|
| 220 |
-
if isinstance(oc, str):
|
| 221 |
-
oc = datatypes.getattr_enum(cutlass_cppgen.OpcodeClass, oc)
|
| 222 |
-
if oc in self.possible_op_classes:
|
| 223 |
-
self.op_class = oc
|
| 224 |
-
else:
|
| 225 |
-
raise Exception(
|
| 226 |
-
f'Unsupported operation class {oc} for CC {self.cc} and data type combination '
|
| 227 |
-
f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and '
|
| 228 |
-
f'layout combination ({self._layout_a}, {self._layout_b}).')
|
| 229 |
-
|
| 230 |
-
# Changing the op class also changes the possible operations available. Reset these.
|
| 231 |
-
self.possible_operations = self.options.operations(
|
| 232 |
-
self.op_class, self._element_a, self._element_b,
|
| 233 |
-
self._element_accumulator, self._layout_a, self._layout_b, self._math_operation)
|
| 234 |
-
|
| 235 |
-
# Changing the op class changes the elements per access in the epilogue. Reset this.
|
| 236 |
-
if self.epilogue_functor is not None:
|
| 237 |
-
self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor)
|
| 238 |
-
|
| 239 |
-
@property
|
| 240 |
-
def math_operation(self) -> cutlass_cppgen.MathOperation:
|
| 241 |
-
"""
|
| 242 |
-
Returns the math operation currently in use
|
| 243 |
-
|
| 244 |
-
:return: math operation currently in use
|
| 245 |
-
:rtype: cutlass_cppgen.MathOperation
|
| 246 |
-
"""
|
| 247 |
-
return self._math_operation
|
| 248 |
-
|
| 249 |
-
@math_operation.setter
|
| 250 |
-
def math_operation(self, mo: cutlass_cppgen.MathOperation):
|
| 251 |
-
if isinstance(mo, str):
|
| 252 |
-
mo = datatypes.getattr_enum(cutlass_cppgen.MathOperation, mo)
|
| 253 |
-
|
| 254 |
-
if not self.specified_kernel_cc:
|
| 255 |
-
if self.current_cc in [90, 100, 101, 103]:
|
| 256 |
-
# CUTLASS 3.0 kernels do not use different math operations. If one is specified, we
|
| 257 |
-
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
| 258 |
-
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
| 259 |
-
self._reset_options(80)
|
| 260 |
-
self._reset_operations(reset_epilogue=False)
|
| 261 |
-
elif self.current_cc in [90, 100, 101, 103]:
|
| 262 |
-
raise Exception("CUTLASS 3.0 kernels do not use different math operations. "
|
| 263 |
-
"To use 2.x kernels with a specific math operation, do not set the `kernel_cc`"
|
| 264 |
-
"parameter when constructing the plan.")
|
| 265 |
-
|
| 266 |
-
self._math_operation = mo
|
| 267 |
-
self._reset_operations()
|
| 268 |
-
|
| 269 |
-
def _elements_per_access(self):
|
| 270 |
-
if self.op_class == cutlass_cppgen.OpcodeClass.Simt:
|
| 271 |
-
return 1
|
| 272 |
-
elif self._element_c != DataType.void:
|
| 273 |
-
return 128 // DataTypeSize[self._element_c]
|
| 274 |
-
else:
|
| 275 |
-
return 128 // max(self.possible_operations.alignments("C"))
|
| 276 |
-
|
| 277 |
-
def _create_epilogue_functor_activation(self, activation):
|
| 278 |
-
"""
|
| 279 |
-
Returns the epilogue functor with given activation function
|
| 280 |
-
"""
|
| 281 |
-
if self.epilogue_functor is None:
|
| 282 |
-
elements_per_access = self._elements_per_access()
|
| 283 |
-
else:
|
| 284 |
-
elements_per_access = self.epilogue_functor.epilogue_vector_length
|
| 285 |
-
|
| 286 |
-
if not self.specified_kernel_cc:
|
| 287 |
-
if self.current_cc in [90, 100, 101, 103] and activation != identity:
|
| 288 |
-
# CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation,
|
| 289 |
-
# revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels.
|
| 290 |
-
cutlass_cppgen.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
|
| 291 |
-
if self._element_c != self._element_d:
|
| 292 |
-
raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
|
| 293 |
-
self._reset_options(80)
|
| 294 |
-
self._reset_operations(reset_epilogue=False)
|
| 295 |
-
elif (self.cc in [90, 100, 101, 103] and self.current_cc not in [90, 100, 101, 103] and activation == identity and self._math_operation is None):
|
| 296 |
-
# SM80 fallback kernels are currently used. Since an identity activation is requested,
|
| 297 |
-
# we can switch back to using SM90 kernels.
|
| 298 |
-
self._reset_options(self.cc)
|
| 299 |
-
self._reset_operations(reset_epilogue=False)
|
| 300 |
-
else:
|
| 301 |
-
if self.current_cc in [90, 100, 101, 103] and activation != identity:
|
| 302 |
-
raise Exception("Epilogues with elementwise fusion are not currently supported "
|
| 303 |
-
"in the Python interface for 3.x kernels. To use 2.x kernels "
|
| 304 |
-
"with fused elementwise epilogues, do not set the `kernel_cc` "
|
| 305 |
-
"parameter when constructing the plan.")
|
| 306 |
-
|
| 307 |
-
return get_activation_epilogue(
|
| 308 |
-
activation,
|
| 309 |
-
self._element_d,
|
| 310 |
-
elements_per_access,
|
| 311 |
-
self._element_accumulator,
|
| 312 |
-
self._element_accumulator,
|
| 313 |
-
)
|
| 314 |
-
|
| 315 |
-
def _reset_epilogue_functor_activation(self, activation):
|
| 316 |
-
"""
|
| 317 |
-
Set the epilogue functor based on the provided activation function
|
| 318 |
-
"""
|
| 319 |
-
self.epilogue_functor = self._create_epilogue_functor_activation(activation)
|
| 320 |
-
|
| 321 |
-
def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor):
|
| 322 |
-
"""
|
| 323 |
-
Reset the alignment of the current epilogue functor based on alignment C
|
| 324 |
-
"""
|
| 325 |
-
if isinstance(epilogue_functor, EpilogueFunctorVisitor):
|
| 326 |
-
return epilogue_functor
|
| 327 |
-
|
| 328 |
-
if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'):
|
| 329 |
-
# Identity epilogue does not have 'activation_functor'
|
| 330 |
-
activation = identity
|
| 331 |
-
else:
|
| 332 |
-
activation = epilogue_functor.activation_functor
|
| 333 |
-
|
| 334 |
-
epilogue_functor = get_activation_epilogue(
|
| 335 |
-
activation,
|
| 336 |
-
self._element_d,
|
| 337 |
-
alignment,
|
| 338 |
-
self._element_accumulator,
|
| 339 |
-
self._element_accumulator,
|
| 340 |
-
)
|
| 341 |
-
return epilogue_functor
|
| 342 |
-
|
| 343 |
-
@property
|
| 344 |
-
def activation(self):
|
| 345 |
-
"""
|
| 346 |
-
Returns the type of the current activation function used
|
| 347 |
-
"""
|
| 348 |
-
if hasattr(self.epilogue_functor, "activation_functor"):
|
| 349 |
-
return self.epilogue_functor.activation_functor
|
| 350 |
-
else:
|
| 351 |
-
return identity
|
| 352 |
-
|
| 353 |
-
@activation.setter
|
| 354 |
-
def activation(self, act):
|
| 355 |
-
"""
|
| 356 |
-
Sets the type of the activation function to use
|
| 357 |
-
Activation can come with a set of arguments
|
| 358 |
-
|
| 359 |
-
:param act: type of activation function to use
|
| 360 |
-
:type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01)
|
| 361 |
-
|
| 362 |
-
"""
|
| 363 |
-
if isinstance(act, tuple):
|
| 364 |
-
if isinstance(act[0], str):
|
| 365 |
-
act_fn = getattr(cutlass_cppgen.backend.epilogue, act[0])
|
| 366 |
-
else:
|
| 367 |
-
act_fn = act[0]
|
| 368 |
-
self._reset_epilogue_functor_activation(act_fn)
|
| 369 |
-
self._activation_args = act[1]
|
| 370 |
-
self._activation = act[0]
|
| 371 |
-
else:
|
| 372 |
-
if isinstance(act, str):
|
| 373 |
-
act = getattr(cutlass_cppgen.backend.epilogue, act)
|
| 374 |
-
self._reset_epilogue_functor_activation(act)
|
| 375 |
-
self._activation = act
|
| 376 |
-
|
| 377 |
-
@property
|
| 378 |
-
def epilogue_visitor(self):
|
| 379 |
-
"""
|
| 380 |
-
Return the epilogue functor
|
| 381 |
-
"""
|
| 382 |
-
return self.epilogue_functor
|
| 383 |
-
|
| 384 |
-
@epilogue_visitor.setter
|
| 385 |
-
def epilogue_visitor(self, visitor):
|
| 386 |
-
"""
|
| 387 |
-
Create the epilogue visitor
|
| 388 |
-
"""
|
| 389 |
-
self.epilogue_functor = EpilogueFunctorVisitor(cc_map[self.cc], visitor)
|
| 390 |
-
|
| 391 |
-
# The epilogue_functor may consume too much shared memory
|
| 392 |
-
# Reset the possible operations
|
| 393 |
-
if self.cc not in [90, 100, 101, 103]:
|
| 394 |
-
# The shared memory is only a concern for sm90+ epilogue
|
| 395 |
-
# In sm80, the epilogue and mainloop share the shared memory
|
| 396 |
-
return
|
| 397 |
-
|
| 398 |
-
datatype_comb = self.possible_operations.datatype_comb
|
| 399 |
-
layout_comb = self.possible_operations.layout_comb
|
| 400 |
-
new_possible_operations = KernelsForDataType(datatype_comb, layout_comb)
|
| 401 |
-
for operation in self.possible_operations.all_operations:
|
| 402 |
-
td = datatypes.td_from_profiler_op(operation)
|
| 403 |
-
# Filter invalid epilogue schedules
|
| 404 |
-
if cc_map[self.cc] == 90 and td.epilogue_schedule not in [
|
| 405 |
-
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized,
|
| 406 |
-
cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative]:
|
| 407 |
-
continue
|
| 408 |
-
epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td)
|
| 409 |
-
|
| 410 |
-
# Verify the maximum number of mainloop stages
|
| 411 |
-
mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm)
|
| 412 |
-
smem_capacity_bytes = SharedMemPerCC[self.cc] << 10
|
| 413 |
-
mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage
|
| 414 |
-
if mainloop_stages < 2:
|
| 415 |
-
# Mainloop stages must >= 2
|
| 416 |
-
continue
|
| 417 |
-
|
| 418 |
-
new_possible_operations.add(operation)
|
| 419 |
-
if len(new_possible_operations.all_operations) == 0:
|
| 420 |
-
raise RuntimeError(
|
| 421 |
-
"The epilogue consumes too much shared memory. "
|
| 422 |
-
"No valid tile description is found in the generator.")
|
| 423 |
-
self.possible_operations = new_possible_operations
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
def run_setup(self):
|
| 427 |
-
"""
|
| 428 |
-
Steps that must be taken before caling `plan.run()`
|
| 429 |
-
"""
|
| 430 |
-
# Initialize the memory pool if, if not already done
|
| 431 |
-
cutlass_cppgen.get_memory_pool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/shape.py
DELETED
|
@@ -1,184 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utilities for expressing shapes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from cutlass_library import (
|
| 38 |
-
ConvMode,
|
| 39 |
-
ConvKind,
|
| 40 |
-
LayoutType
|
| 41 |
-
)
|
| 42 |
-
from cutlass_cppgen.backend.c_types import (
|
| 43 |
-
Conv2DProblemSize_,
|
| 44 |
-
GemmCoord_,
|
| 45 |
-
GemmCoordBatched_
|
| 46 |
-
)
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
class MatrixCoord:
|
| 50 |
-
def __init__(self, row, col):
|
| 51 |
-
self._row = row
|
| 52 |
-
self._col = col
|
| 53 |
-
|
| 54 |
-
@property
|
| 55 |
-
def row(self):
|
| 56 |
-
return self._row
|
| 57 |
-
|
| 58 |
-
@property
|
| 59 |
-
def column(self):
|
| 60 |
-
return self._col
|
| 61 |
-
|
| 62 |
-
def leading_dimension(self, layout: LayoutType) -> int:
|
| 63 |
-
"""
|
| 64 |
-
Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.
|
| 65 |
-
|
| 66 |
-
:param layout: layout of matrix
|
| 67 |
-
:type layout: cutlass_library.LayoutType
|
| 68 |
-
|
| 69 |
-
:returns: leading dimension
|
| 70 |
-
:rtype: int
|
| 71 |
-
"""
|
| 72 |
-
if layout == LayoutType.RowMajor:
|
| 73 |
-
return self._col
|
| 74 |
-
elif layout == LayoutType.ColumnMajor:
|
| 75 |
-
return self._row
|
| 76 |
-
else:
|
| 77 |
-
raise Exception(f'Unsupported layout for leading dimension calculation: {layout}')
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class GemmCoord:
|
| 81 |
-
def __init__(self, m: int, n: int, k: int):
|
| 82 |
-
self._m = m
|
| 83 |
-
self._n = n
|
| 84 |
-
self._k = k
|
| 85 |
-
|
| 86 |
-
@property
|
| 87 |
-
def m(self) -> int:
|
| 88 |
-
return self._m
|
| 89 |
-
|
| 90 |
-
@property
|
| 91 |
-
def n(self) -> int:
|
| 92 |
-
return self._n
|
| 93 |
-
|
| 94 |
-
@property
|
| 95 |
-
def k(self) -> int:
|
| 96 |
-
return self._k
|
| 97 |
-
|
| 98 |
-
@property
|
| 99 |
-
def mk(self) -> MatrixCoord:
|
| 100 |
-
return MatrixCoord(self._m, self._k)
|
| 101 |
-
|
| 102 |
-
@property
|
| 103 |
-
def mn(self) -> MatrixCoord:
|
| 104 |
-
return MatrixCoord(self._m, self._n)
|
| 105 |
-
|
| 106 |
-
@property
|
| 107 |
-
def kn(self) -> MatrixCoord:
|
| 108 |
-
return MatrixCoord(self._k, self._n)
|
| 109 |
-
|
| 110 |
-
@property
|
| 111 |
-
def ctype(self) -> GemmCoord_:
|
| 112 |
-
return GemmCoord_(self._m, self._n, self._k)
|
| 113 |
-
|
| 114 |
-
def batched_ctype(self, batch_count: int) -> GemmCoordBatched_:
|
| 115 |
-
return GemmCoordBatched_(self._m, self._n, self._k, batch_count)
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
class Conv2DProblemSize:
|
| 119 |
-
def __init__(
|
| 120 |
-
self, n: int, h: int, w: int, c: int,
|
| 121 |
-
k: int, r: int, s: int, c_: int,
|
| 122 |
-
pad_h: int, pad_w: int, stride_h: int, stride_w: int,
|
| 123 |
-
dilation_h: int, dilation_w: int, mode: ConvMode=ConvMode.CrossCorrelation,
|
| 124 |
-
split_k_slices: int=1, groups: int=1):
|
| 125 |
-
|
| 126 |
-
self.N = n
|
| 127 |
-
self.H = h
|
| 128 |
-
self.W = w
|
| 129 |
-
self.C = c
|
| 130 |
-
self.K = k
|
| 131 |
-
self.R = r
|
| 132 |
-
self.S = s
|
| 133 |
-
self.pad_h = pad_h
|
| 134 |
-
self.pad_w = pad_w
|
| 135 |
-
self.stride_h = stride_h
|
| 136 |
-
self.stride_w = stride_w
|
| 137 |
-
self.dilation_h = dilation_h
|
| 138 |
-
self.dilation_w = dilation_w
|
| 139 |
-
self.mode = int(mode)
|
| 140 |
-
self.split_k_slices = split_k_slices
|
| 141 |
-
self.groups = groups
|
| 142 |
-
self.P = ((h + pad_h * 2 - r * dilation_h) // stride_h) + 1
|
| 143 |
-
self.Q = ((w + pad_w * 2 - s * dilation_w) // stride_w) + 1
|
| 144 |
-
|
| 145 |
-
@property
|
| 146 |
-
def ctype(self) -> Conv2DProblemSize_:
|
| 147 |
-
return Conv2DProblemSize_(self)
|
| 148 |
-
|
| 149 |
-
def implicit_gemm_size(self, kind: ConvKind):
|
| 150 |
-
if kind == ConvKind.Fprop:
|
| 151 |
-
return GemmCoord(
|
| 152 |
-
self.N * self.P * self.Q,
|
| 153 |
-
self.K,
|
| 154 |
-
self.R * self.S * self.C // self.groups
|
| 155 |
-
)
|
| 156 |
-
elif kind == ConvKind.Dgrad:
|
| 157 |
-
return GemmCoord(
|
| 158 |
-
self.N * self.H * self.W,
|
| 159 |
-
self.C,
|
| 160 |
-
self.R * self.S * self.K
|
| 161 |
-
)
|
| 162 |
-
elif kind == ConvKind.Wgrad:
|
| 163 |
-
return GemmCoord(
|
| 164 |
-
self.K,
|
| 165 |
-
self.R * self.S * self.C,
|
| 166 |
-
self.N * self.P * self.Q
|
| 167 |
-
)
|
| 168 |
-
|
| 169 |
-
@staticmethod
|
| 170 |
-
def from_sizes(input_size, weight_size):
|
| 171 |
-
K, R, S, _ = weight_size
|
| 172 |
-
pad_h = R // 2
|
| 173 |
-
pad_w = S // 2
|
| 174 |
-
stride_h = 1
|
| 175 |
-
stride_w = 1
|
| 176 |
-
dilation_h = 1
|
| 177 |
-
dilation_w = 1
|
| 178 |
-
return Conv2DProblemSize(
|
| 179 |
-
*input_size,
|
| 180 |
-
*weight_size,
|
| 181 |
-
pad_h, pad_w,
|
| 182 |
-
stride_h, stride_w,
|
| 183 |
-
dilation_h, dilation_w
|
| 184 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/swizzle.py
DELETED
|
@@ -1,65 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Registry of swizzling functions
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
from cutlass_library import SwizzlingFunctor
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
IdentitySwizzle1 = SwizzlingFunctor.Identity1
|
| 41 |
-
IdentitySwizzle2 = SwizzlingFunctor.Identity2
|
| 42 |
-
IdentitySwizzle4 = SwizzlingFunctor.Identity4
|
| 43 |
-
IdentitySwizzle8 = SwizzlingFunctor.Identity8
|
| 44 |
-
HorizontalSwizzle = SwizzlingFunctor.Horizontal
|
| 45 |
-
ThreadblockSwizzleStreamK = SwizzlingFunctor.StreamK
|
| 46 |
-
StridedDgradIdentitySwizzle1 = SwizzlingFunctor.StridedDgradIdentity1
|
| 47 |
-
StridedDgradIdentitySwizzle4 = SwizzlingFunctor.StridedDgradIdentity4
|
| 48 |
-
StridedDgradHorizontalSwizzle = SwizzlingFunctor.StridedDgradHorizontal
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
_swizzling_functors = [
|
| 52 |
-
IdentitySwizzle1,
|
| 53 |
-
IdentitySwizzle2,
|
| 54 |
-
IdentitySwizzle4,
|
| 55 |
-
IdentitySwizzle8,
|
| 56 |
-
HorizontalSwizzle,
|
| 57 |
-
ThreadblockSwizzleStreamK,
|
| 58 |
-
StridedDgradIdentitySwizzle1,
|
| 59 |
-
StridedDgradIdentitySwizzle4,
|
| 60 |
-
StridedDgradHorizontalSwizzle,
|
| 61 |
-
]
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def get_swizzling_functors():
|
| 65 |
-
return _swizzling_functors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/__init__.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
from cutlass_cppgen.utils.check import (
|
| 34 |
-
alignment_or_default,
|
| 35 |
-
calculate_smem_usage,
|
| 36 |
-
calculate_smem_usage_per_stage,
|
| 37 |
-
valid_cluster_shape,
|
| 38 |
-
valid_schedule,
|
| 39 |
-
valid_stage_count,
|
| 40 |
-
update_alignment,
|
| 41 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/check.py
DELETED
|
@@ -1,262 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utility functions for checking constraints on kernels and calculating kernel attributes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import ctypes
|
| 38 |
-
|
| 39 |
-
from cutlass_library import DataTypeSize, KernelScheduleSuffixes, OperationKind, SharedMemPerCC
|
| 40 |
-
|
| 41 |
-
import cutlass_cppgen
|
| 42 |
-
from cutlass_cppgen.backend.library import TileDescription
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
|
| 46 |
-
"""
|
| 47 |
-
Returns the amount of shared memory in bytes consumed in a single stage of a kernel.
|
| 48 |
-
|
| 49 |
-
:param td: tile description to compute shared memory of
|
| 50 |
-
:type td: TileDescription
|
| 51 |
-
:param operation_kind: identifier for the type of operation being performed
|
| 52 |
-
:type operation_kind: cutlass_library.OperationKind
|
| 53 |
-
|
| 54 |
-
:return: number of bytes of shared memory consumed by a single stage
|
| 55 |
-
:rtype: int
|
| 56 |
-
"""
|
| 57 |
-
m, n, k = td.blackwell_threadblock_shape
|
| 58 |
-
if td.is_2sm:
|
| 59 |
-
m //= 2
|
| 60 |
-
|
| 61 |
-
if operation_kind == OperationKind.Gemm:
|
| 62 |
-
stage_barrier_bytes = 32
|
| 63 |
-
return (
|
| 64 |
-
(DataTypeSize[td.math_instruction.element_a] * m * k // 8)
|
| 65 |
-
+ (DataTypeSize[td.math_instruction.element_b] * k * n // 8)
|
| 66 |
-
+ stage_barrier_bytes
|
| 67 |
-
)
|
| 68 |
-
else:
|
| 69 |
-
raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}")
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
def calculate_smem_usage(operation) -> int:
|
| 73 |
-
"""
|
| 74 |
-
Returns the amount of shared memory in bytes consumed by a kernel.
|
| 75 |
-
|
| 76 |
-
:return: number of bytes of shared memory consumed by the operation
|
| 77 |
-
:return: int
|
| 78 |
-
"""
|
| 79 |
-
_per_stage = calculate_smem_usage_per_stage(operation.tile_description, operation.operation_kind)
|
| 80 |
-
return _per_stage * operation.tile_description.stages
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def valid_stage_count(
|
| 84 |
-
cc: int,
|
| 85 |
-
kernel_cc: int,
|
| 86 |
-
td: TileDescription,
|
| 87 |
-
element_C: cutlass_cppgen.DataType = None,
|
| 88 |
-
element_D: cutlass_cppgen.DataType = None,
|
| 89 |
-
verbose: bool = True) -> tuple:
|
| 90 |
-
"""
|
| 91 |
-
Checks whether a device with `cc` supports the number of stages within `tile_description`, both
|
| 92 |
-
based on raw limits on the number of stages and based on shared memory capacity
|
| 93 |
-
|
| 94 |
-
:param cc: compute capability of device in question
|
| 95 |
-
:type cc: int
|
| 96 |
-
:param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
|
| 97 |
-
:type kernel_cc: int
|
| 98 |
-
:param td: tile description to check
|
| 99 |
-
:type td: TileDescription
|
| 100 |
-
:param element_C: data type of operand C
|
| 101 |
-
:type element_C: cutlass_cppgen.DataType
|
| 102 |
-
:param element_D: data type of operand D
|
| 103 |
-
:type element_D: cutlass_cppgen.DataType
|
| 104 |
-
:param verbose: whether to log warnings
|
| 105 |
-
:type verbose: bool
|
| 106 |
-
|
| 107 |
-
:return: tuple with the first element indicating whether the provided tile description is
|
| 108 |
-
valid for the provided device and the second element being an error message
|
| 109 |
-
:rtype: tuple
|
| 110 |
-
"""
|
| 111 |
-
if kernel_cc in [90, 100, 101, 103]:
|
| 112 |
-
if (td.stages is None or td.stages == 0):
|
| 113 |
-
# Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
|
| 114 |
-
# determines the stage count to use. Thus, all settings are valid in these scenarios.
|
| 115 |
-
return (True, "")
|
| 116 |
-
elif verbose:
|
| 117 |
-
cutlass_cppgen.logger.warning(
|
| 118 |
-
"Setting an explicit stage count for SM90 kernels currently may "
|
| 119 |
-
"result in compilation errors if the combination of tile shape, "
|
| 120 |
-
"stage count, and shared memory requirement of the epilogue exceeds "
|
| 121 |
-
"the available shared memory per SM.")
|
| 122 |
-
|
| 123 |
-
if td.stages <= 0:
|
| 124 |
-
return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")
|
| 125 |
-
|
| 126 |
-
if cc < 80 and td.stages != 2:
|
| 127 |
-
return (False, f"Tile description has stage count of {td.stages}, "
|
| 128 |
-
f"but only 2 stages are supported on SM{cc}.")
|
| 129 |
-
|
| 130 |
-
# The calculation below does not consider shared memory used by the epilogue and, thus,
|
| 131 |
-
# only catches cases in which the mainloop exceeds the device's shared memory capacity.
|
| 132 |
-
# This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
|
| 133 |
-
# mainloop and epilogue is shared.
|
| 134 |
-
smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
|
| 135 |
-
smem_usage_mainloop = (smem_per_stage * td.stages)
|
| 136 |
-
smem_arch = SharedMemPerCC[cc] << 10
|
| 137 |
-
if smem_usage_mainloop > smem_arch:
|
| 138 |
-
return ( False,
|
| 139 |
-
"Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
|
| 140 |
-
f"Details:\n"
|
| 141 |
-
f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
|
| 142 |
-
f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
|
| 143 |
-
f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")
|
| 144 |
-
|
| 145 |
-
return (True, "")
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
|
| 149 |
-
"""
|
| 150 |
-
Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.
|
| 151 |
-
|
| 152 |
-
:param cc: compute capability of device in question
|
| 153 |
-
:type cc: int
|
| 154 |
-
:param cluster_shape: dimensions of thread block cluster shape to check
|
| 155 |
-
:type cluster_shape: list
|
| 156 |
-
|
| 157 |
-
:return: tuple with the first element indicating whether the provided cluster shape is
|
| 158 |
-
valid for the provided device and the second element being an error message
|
| 159 |
-
:rtype: tuple
|
| 160 |
-
"""
|
| 161 |
-
|
| 162 |
-
if cc < 90 or cc in [120, 121]:
|
| 163 |
-
if cluster_shape != [1, 1, 1]:
|
| 164 |
-
return (False,
|
| 165 |
-
f"Cluster shape for pre-SM90 architectures and SM 120 and 121 must be [1, 1, 1]. Received cluster shape of "
|
| 166 |
-
f"{cluster_shape} for SM{cc}.")
|
| 167 |
-
else:
|
| 168 |
-
return (True, "")
|
| 169 |
-
|
| 170 |
-
if len(cluster_shape) != 3:
|
| 171 |
-
return (False,
|
| 172 |
-
f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)}")
|
| 173 |
-
|
| 174 |
-
if cluster_shape[2] != 1:
|
| 175 |
-
return (False,
|
| 176 |
-
"CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
|
| 177 |
-
f"Received cluster shape of {cluster_shape}.")
|
| 178 |
-
|
| 179 |
-
return (True, "")
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
def valid_schedule(
|
| 183 |
-
cc: int,
|
| 184 |
-
kernel_schedule: cutlass_cppgen.KernelScheduleType,
|
| 185 |
-
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
|
| 186 |
-
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> tuple:
|
| 187 |
-
"""
|
| 188 |
-
Checks that the kernel and epilogue schedules passed in are a valid combination for
|
| 189 |
-
a device of compute capability ``cc``.
|
| 190 |
-
|
| 191 |
-
:param cc: compute capability of device in question
|
| 192 |
-
:type cc: int
|
| 193 |
-
:param kernel_schedule: kernel schedule type
|
| 194 |
-
:type kernel_schedule: cutlass_cppgen.KernelScheduleType
|
| 195 |
-
:param epilogue_schedule: epilogue schedule type
|
| 196 |
-
:type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
|
| 197 |
-
:param tile_scheduler: tile scheduler type
|
| 198 |
-
:type tile_scheduler: cutlass_cppgen.TileSchedulerType
|
| 199 |
-
|
| 200 |
-
:return: tuple with the first element indicating whether the provided schedules are
|
| 201 |
-
valid for the provided device and the second element being an error message
|
| 202 |
-
:rtype: tuple
|
| 203 |
-
"""
|
| 204 |
-
kernel_auto = (kernel_schedule == cutlass_cppgen.KernelScheduleType.ScheduleAuto)
|
| 205 |
-
epilogue_auto = (epilogue_schedule == cutlass_cppgen.EpilogueScheduleType.ScheduleAuto)
|
| 206 |
-
tile_scheduler_default = (tile_scheduler == cutlass_cppgen.TileSchedulerType.Default)
|
| 207 |
-
if (cc < 90 or cc in [120, 121]) and not (kernel_auto and epilogue_auto and tile_scheduler_default):
|
| 208 |
-
return (False, "Non-default schedules are only supported on SM90 and beyond (excluding SM120 and SM121)")
|
| 209 |
-
|
| 210 |
-
if cc == 90 and ((kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto)):
|
| 211 |
-
return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")
|
| 212 |
-
|
| 213 |
-
if not tile_scheduler_default:
|
| 214 |
-
cooperative_kernels = [cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
|
| 215 |
-
cutlass_cppgen.KernelScheduleType.CpAsyncWarpSpecializedCooperative]
|
| 216 |
-
if cc == 90 and (tile_scheduler == cutlass_cppgen.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels):
|
| 217 |
-
return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
|
| 218 |
-
return (True, "")
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
|
| 222 |
-
"""
|
| 223 |
-
Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
|
| 224 |
-
that `alignment_provided` does not exceed `default_alignment`.
|
| 225 |
-
|
| 226 |
-
:param alignment_provided: alignment preference specified. Can be None.
|
| 227 |
-
:type alignment_provided: int
|
| 228 |
-
:param default_alignment: alignment to use if `alignment_provided` is None
|
| 229 |
-
:type default_alignment: int
|
| 230 |
-
|
| 231 |
-
:return: alignment to use
|
| 232 |
-
:rtype: int
|
| 233 |
-
"""
|
| 234 |
-
if alignment_provided is not None:
|
| 235 |
-
if alignment_provided > default_alignment:
|
| 236 |
-
raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
|
| 237 |
-
return alignment_provided
|
| 238 |
-
|
| 239 |
-
return default_alignment
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
def update_alignment(alignment_provided:int, default_alignment: int) -> int:
|
| 243 |
-
"""
|
| 244 |
-
Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks
|
| 245 |
-
that `alignment_provided` does not exceed `default_alignment`.
|
| 246 |
-
|
| 247 |
-
:param alignment_provided: alignment preference specified. Can be None.
|
| 248 |
-
:type alignment_provided: int
|
| 249 |
-
:param default_alignment: alignment to use if `alignment_provided` is None
|
| 250 |
-
:type default_alignment: int
|
| 251 |
-
|
| 252 |
-
:return: alignment to use
|
| 253 |
-
:rtype: int
|
| 254 |
-
"""
|
| 255 |
-
if alignment_provided is not None:
|
| 256 |
-
if alignment_provided > default_alignment:
|
| 257 |
-
if alignment_provided % default_alignment == 0:
|
| 258 |
-
return default_alignment
|
| 259 |
-
raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
|
| 260 |
-
return alignment_provided
|
| 261 |
-
|
| 262 |
-
return default_alignment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/datatypes.py
DELETED
|
@@ -1,362 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utility functions for converting between frontend datatypes and CUTLASS datatypes
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import cutlass_cppgen
|
| 38 |
-
from cutlass_library import (
|
| 39 |
-
DataTypeSize,
|
| 40 |
-
MathOperation,
|
| 41 |
-
MathInstruction
|
| 42 |
-
)
|
| 43 |
-
from cutlass_cppgen.backend.library import (
|
| 44 |
-
TileDescription,
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
bfloat16_available = None
|
| 48 |
-
cupy_available = None
|
| 49 |
-
numpy_available = None
|
| 50 |
-
torch_available = None
|
| 51 |
-
_library_to_cupy_dict = None
|
| 52 |
-
_library_to_numpy_dict = None
|
| 53 |
-
_library_to_torch_dict = None
|
| 54 |
-
_torch_to_library_dict = None
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
def is_numpy_available():
|
| 58 |
-
global numpy_available, _library_to_numpy_dict
|
| 59 |
-
if numpy_available is None:
|
| 60 |
-
try:
|
| 61 |
-
import numpy as np
|
| 62 |
-
|
| 63 |
-
numpy_available = True
|
| 64 |
-
_library_to_numpy_dict = {
|
| 65 |
-
cutlass_cppgen.DataType.f16: np.float16,
|
| 66 |
-
cutlass_cppgen.DataType.f32: np.float32,
|
| 67 |
-
cutlass_cppgen.DataType.f64: np.float64,
|
| 68 |
-
cutlass_cppgen.DataType.s8: np.int8,
|
| 69 |
-
cutlass_cppgen.DataType.s32: np.int32,
|
| 70 |
-
}
|
| 71 |
-
except ImportError:
|
| 72 |
-
numpy_available = False
|
| 73 |
-
_library_to_numpy_dict = {}
|
| 74 |
-
return numpy_available
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def is_numpy_tensor(inp) -> bool:
|
| 78 |
-
if is_numpy_available():
|
| 79 |
-
import numpy as np
|
| 80 |
-
return isinstance(inp, np.ndarray)
|
| 81 |
-
return False
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def numpy_library_type(inp) -> cutlass_cppgen.DataType:
|
| 85 |
-
if is_numpy_available():
|
| 86 |
-
import numpy as np
|
| 87 |
-
if inp == np.float16:
|
| 88 |
-
return cutlass_cppgen.DataType.f16
|
| 89 |
-
elif inp == np.float32:
|
| 90 |
-
return cutlass_cppgen.DataType.f32
|
| 91 |
-
elif inp == np.float64:
|
| 92 |
-
return cutlass_cppgen.DataType.f64
|
| 93 |
-
elif inp == np.int8:
|
| 94 |
-
return cutlass_cppgen.DataType.s8
|
| 95 |
-
elif inp == np.int32:
|
| 96 |
-
return cutlass_cppgen.DataType.s32
|
| 97 |
-
return None
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
def numpy_type(inp):
|
| 101 |
-
return _library_to_numpy_dict.get(inp, None)
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
def is_cupy_available():
|
| 105 |
-
global cupy_available
|
| 106 |
-
if cupy_available is None:
|
| 107 |
-
try:
|
| 108 |
-
import cupy as cp
|
| 109 |
-
|
| 110 |
-
cupy_available = True
|
| 111 |
-
_library_to_cupy_dict = {
|
| 112 |
-
cutlass_cppgen.DataType.f16: cp.float16,
|
| 113 |
-
cutlass_cppgen.DataType.f32: cp.float32,
|
| 114 |
-
cutlass_cppgen.DataType.f64: cp.float64,
|
| 115 |
-
cutlass_cppgen.DataType.s8: cp.int8,
|
| 116 |
-
cutlass_cppgen.DataType.s32: cp.int32,
|
| 117 |
-
}
|
| 118 |
-
except ImportError:
|
| 119 |
-
cupy_available = False
|
| 120 |
-
_library_to_cupy_dict = {}
|
| 121 |
-
return cupy_available
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def is_cupy_tensor(inp) -> bool:
|
| 125 |
-
if is_cupy_available():
|
| 126 |
-
import cupy as cp
|
| 127 |
-
return isinstance(inp, cp.ndarray)
|
| 128 |
-
return False
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def cupy_library_type(inp) -> cutlass_cppgen.DataType:
|
| 132 |
-
if is_cupy_available():
|
| 133 |
-
import cupy as cp
|
| 134 |
-
if inp == cp.float16:
|
| 135 |
-
return cutlass_cppgen.DataType.f16
|
| 136 |
-
elif inp == cp.float32:
|
| 137 |
-
return cutlass_cppgen.DataType.f32
|
| 138 |
-
elif inp == cp.float64:
|
| 139 |
-
return cutlass_cppgen.DataType.f64
|
| 140 |
-
return None
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
def cupy_type(inp):
|
| 144 |
-
return _library_to_cupy_dict.get(inp, None)
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
def is_torch_available():
|
| 148 |
-
global torch_available, _library_to_torch_dict, _torch_to_library_dict
|
| 149 |
-
if torch_available is None:
|
| 150 |
-
try:
|
| 151 |
-
import torch
|
| 152 |
-
|
| 153 |
-
torch_available = True
|
| 154 |
-
_torch_to_library_dict = {
|
| 155 |
-
torch.half: cutlass_cppgen.DataType.f16,
|
| 156 |
-
torch.float16: cutlass_cppgen.DataType.f16,
|
| 157 |
-
torch.bfloat16: cutlass_cppgen.DataType.bf16,
|
| 158 |
-
torch.float: cutlass_cppgen.DataType.f32,
|
| 159 |
-
torch.float32: cutlass_cppgen.DataType.f32,
|
| 160 |
-
torch.double: cutlass_cppgen.DataType.f64,
|
| 161 |
-
torch.float64: cutlass_cppgen.DataType.f64,
|
| 162 |
-
torch.int8: cutlass_cppgen.DataType.s8,
|
| 163 |
-
torch.int32: cutlass_cppgen.DataType.s32,
|
| 164 |
-
torch.uint8: cutlass_cppgen.DataType.u8,
|
| 165 |
-
}
|
| 166 |
-
|
| 167 |
-
_library_to_torch_dict = {
|
| 168 |
-
cutlass_cppgen.DataType.f16: torch.half,
|
| 169 |
-
cutlass_cppgen.DataType.f16: torch.float16,
|
| 170 |
-
cutlass_cppgen.DataType.bf16: torch.bfloat16,
|
| 171 |
-
cutlass_cppgen.DataType.f32: torch.float,
|
| 172 |
-
cutlass_cppgen.DataType.f32: torch.float32,
|
| 173 |
-
cutlass_cppgen.DataType.f64: torch.double,
|
| 174 |
-
cutlass_cppgen.DataType.f64: torch.float64,
|
| 175 |
-
cutlass_cppgen.DataType.s8: torch.int8,
|
| 176 |
-
cutlass_cppgen.DataType.s32: torch.int32,
|
| 177 |
-
cutlass_cppgen.DataType.u8: torch.uint8,
|
| 178 |
-
}
|
| 179 |
-
|
| 180 |
-
def possibly_add_type(torch_type_name, cutlass_type):
|
| 181 |
-
# Only try adding the type if the version of torch being used supports it
|
| 182 |
-
if hasattr(torch, torch_type_name):
|
| 183 |
-
torch_type = getattr(torch, torch_type_name)
|
| 184 |
-
_torch_to_library_dict[torch_type] = cutlass_type
|
| 185 |
-
_library_to_torch_dict[cutlass_type] = torch_type
|
| 186 |
-
|
| 187 |
-
possibly_add_type("float8_e4m3fn", cutlass_cppgen.DataType.e4m3)
|
| 188 |
-
possibly_add_type("float8_e5m2", cutlass_cppgen.DataType.e5m2)
|
| 189 |
-
|
| 190 |
-
except ImportError:
|
| 191 |
-
torch_available = False
|
| 192 |
-
_torch_to_library_dict = {}
|
| 193 |
-
_library_to_torch_dict = {}
|
| 194 |
-
return torch_available
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def is_torch_tensor(inp) -> bool:
|
| 198 |
-
if is_torch_available():
|
| 199 |
-
import torch
|
| 200 |
-
return isinstance(inp, torch.Tensor)
|
| 201 |
-
return False
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
def torch_library_type(inp) -> cutlass_cppgen.DataType:
|
| 205 |
-
return _torch_to_library_dict.get(inp, None)
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
def torch_type(inp):
|
| 209 |
-
return _library_to_torch_dict.get(inp, None)
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
def is_bfloat16_available():
|
| 213 |
-
global bfloat16_available
|
| 214 |
-
|
| 215 |
-
if bfloat16_available is None:
|
| 216 |
-
try:
|
| 217 |
-
import bfloat16
|
| 218 |
-
|
| 219 |
-
bfloat16_available = True
|
| 220 |
-
except ImportError:
|
| 221 |
-
bfloat16_available = False
|
| 222 |
-
return bfloat16_available
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
def bfloat16_library_type(inp) -> cutlass_cppgen.DataType:
|
| 226 |
-
if is_bfloat16_available():
|
| 227 |
-
import bfloat16
|
| 228 |
-
if inp == bfloat16.bfloat16:
|
| 229 |
-
return cutlass_cppgen.DataType.bf16
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
def bfloat16_type(inp):
|
| 233 |
-
if is_bfloat16_available():
|
| 234 |
-
import bfloat16
|
| 235 |
-
if inp == cutlass_cppgen.DataType.bf16:
|
| 236 |
-
return bfloat16.bfloat16
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
def library_type(inp):
|
| 240 |
-
if inp in DataTypeSize:
|
| 241 |
-
return inp
|
| 242 |
-
|
| 243 |
-
for cvt_fn in [
|
| 244 |
-
bfloat16_library_type,
|
| 245 |
-
cupy_library_type,
|
| 246 |
-
numpy_library_type,
|
| 247 |
-
torch_library_type,
|
| 248 |
-
]:
|
| 249 |
-
out = cvt_fn(inp)
|
| 250 |
-
if out is not None:
|
| 251 |
-
return out
|
| 252 |
-
|
| 253 |
-
raise Exception(f"No available conversion from type {inp} to a library type.")
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def _tensor_from_numpy(np_tensor):
|
| 257 |
-
dtype = library_type(np_tensor.dtype)
|
| 258 |
-
if np_tensor.flags.c_contiguous:
|
| 259 |
-
layout = cutlass_cppgen.LayoutType.RowMajor
|
| 260 |
-
elif np_tensor.flags.f_contiguous:
|
| 261 |
-
layout = cutlass_cppgen.LayoutType.ColumnMajor
|
| 262 |
-
return (dtype, layout)
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
def _tensor_from_torch(pt_tensor):
|
| 266 |
-
dtype = library_type(pt_tensor.dtype)
|
| 267 |
-
return (dtype, cutlass_cppgen.LayoutType.RowMajor)
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
def get_datatype_and_layout(tensor):
|
| 271 |
-
if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
|
| 272 |
-
return _tensor_from_numpy(tensor)
|
| 273 |
-
elif is_torch_tensor(tensor):
|
| 274 |
-
return _tensor_from_torch(tensor)
|
| 275 |
-
elif isinstance(tensor, float) or isinstance(tensor, int):
|
| 276 |
-
return (cutlass_cppgen.DataType.f32, cutlass_cppgen.LayoutType.RowMajor)
|
| 277 |
-
else:
|
| 278 |
-
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
def get_tensor_shape(tensor, op="GEMM"):
|
| 282 |
-
if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)):
|
| 283 |
-
return tensor.shape
|
| 284 |
-
elif is_torch_tensor(tensor):
|
| 285 |
-
size = tensor.size()
|
| 286 |
-
if op == "CONV":
|
| 287 |
-
# PyTorch Tensors have shape NCHW
|
| 288 |
-
return (size[0], size[2], size[3], size[1])
|
| 289 |
-
else:
|
| 290 |
-
return tuple(tensor.size())
|
| 291 |
-
elif isinstance(tensor, float) or isinstance(tensor, int):
|
| 292 |
-
return (1,)
|
| 293 |
-
else:
|
| 294 |
-
raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
_math_operation_value_map = {x.value: x for x in MathOperation}
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
def backend_math_operation(math_op: MathOperation):
|
| 301 |
-
if math_op.value not in _math_operation_value_map.keys():
|
| 302 |
-
raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.")
|
| 303 |
-
return _math_operation_value_map[math_op.value]
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
def construct_backend_td(td: cutlass_cppgen.TileDescription,
|
| 307 |
-
kernel_schedule: cutlass_cppgen.KernelScheduleType,
|
| 308 |
-
epilogue_schedule: cutlass_cppgen.EpilogueScheduleType,
|
| 309 |
-
tile_scheduler: cutlass_cppgen.TileSchedulerType) -> TileDescription:
|
| 310 |
-
mi = td.math_instruction
|
| 311 |
-
backend_mi = MathInstruction(
|
| 312 |
-
mi.instruction_shape,
|
| 313 |
-
mi.element_a,
|
| 314 |
-
mi.element_b,
|
| 315 |
-
mi.element_accumulator,
|
| 316 |
-
mi.opcode_class,
|
| 317 |
-
backend_math_operation(mi.math_operation)
|
| 318 |
-
)
|
| 319 |
-
cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1]
|
| 320 |
-
return TileDescription(td.threadblock_shape, td.stages, td.warp_count,
|
| 321 |
-
backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler)
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
def td_from_profiler_op(op) -> TileDescription:
|
| 325 |
-
"""
|
| 326 |
-
Converts the profiler's TileDescription in ``op`` into the backend TileDescription
|
| 327 |
-
|
| 328 |
-
:param op: profiler Operation
|
| 329 |
-
|
| 330 |
-
:returns: backend TileDescription
|
| 331 |
-
:rtype: cutlass_cppgen.backend.TileDescription
|
| 332 |
-
"""
|
| 333 |
-
kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None
|
| 334 |
-
eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None
|
| 335 |
-
tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None
|
| 336 |
-
return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule)
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
def td_from_profiler_td(td: TileDescription) -> TileDescription:
|
| 340 |
-
"""
|
| 341 |
-
Converts the profiler's TileDescription into the backend TileDescription
|
| 342 |
-
|
| 343 |
-
:param td: profiler TileDescription
|
| 344 |
-
:type td: cutlass_cppgen.TileDescription
|
| 345 |
-
|
| 346 |
-
:returns: backend TileDescription
|
| 347 |
-
:rtype: cutlass_cppgen.backend.TileDescription
|
| 348 |
-
"""
|
| 349 |
-
return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None)
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
def to_camel_case(snake_str):
|
| 353 |
-
return "".join(x.capitalize() for x in snake_str.lower().split("_"))
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
def getattr_enum(obj, attr_name):
|
| 357 |
-
# The attr_name is under the snake_case
|
| 358 |
-
camel_attr = to_camel_case(attr_name)
|
| 359 |
-
if hasattr(obj, camel_attr):
|
| 360 |
-
return getattr(obj, camel_attr)
|
| 361 |
-
else:
|
| 362 |
-
raise Exception(f"Invalid option: {attr_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/lazy_import.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
import importlib
|
| 33 |
-
from typing import Any
|
| 34 |
-
|
| 35 |
-
def lazy_import(mod_name: str) -> Any:
|
| 36 |
-
class Lazy:
|
| 37 |
-
def __getattr__(self, name:str) -> Any:
|
| 38 |
-
module = importlib.import_module(mod_name)
|
| 39 |
-
return getattr(module, name)
|
| 40 |
-
|
| 41 |
-
return Lazy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_cppgen/utils/profiler.py
DELETED
|
@@ -1,196 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Profiler based on the cuda events
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import re
|
| 38 |
-
import subprocess
|
| 39 |
-
|
| 40 |
-
from cutlass_cppgen.utils.lazy_import import lazy_import
|
| 41 |
-
cuda = lazy_import("cuda.cuda")
|
| 42 |
-
cudart = lazy_import("cuda.cudart")
|
| 43 |
-
import numpy as np
|
| 44 |
-
|
| 45 |
-
from cutlass_cppgen import CUTLASS_PATH
|
| 46 |
-
from cutlass_cppgen.backend.library import DataTypeSize
|
| 47 |
-
from cutlass_cppgen.op.op import OperationBase
|
| 48 |
-
from cutlass_cppgen.shape import GemmCoord
|
| 49 |
-
from cutlass_cppgen.utils.datatypes import is_numpy_tensor
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
class GpuTimer:
|
| 53 |
-
def __init__(self) -> None:
|
| 54 |
-
self.events = [
|
| 55 |
-
cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
|
| 56 |
-
cuda.cuEventCreate(cuda.CUevent_flags.CU_EVENT_DEFAULT)[1],
|
| 57 |
-
]
|
| 58 |
-
|
| 59 |
-
def start(self, stream=None):
|
| 60 |
-
if not stream:
|
| 61 |
-
stream = cuda.CUstream(0)
|
| 62 |
-
|
| 63 |
-
(err,) = cuda.cuEventRecord(self.events[0], stream)
|
| 64 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 65 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 66 |
-
|
| 67 |
-
def stop(self, stream=None):
|
| 68 |
-
if not stream:
|
| 69 |
-
stream = cuda.CUstream(0)
|
| 70 |
-
|
| 71 |
-
(err,) = cuda.cuEventRecord(self.events[1], stream)
|
| 72 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 73 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 74 |
-
pass
|
| 75 |
-
|
| 76 |
-
def stop_and_wait(self, stream=None):
|
| 77 |
-
if not stream:
|
| 78 |
-
stream = cuda.CUstream(0)
|
| 79 |
-
|
| 80 |
-
self.stop(stream)
|
| 81 |
-
if stream:
|
| 82 |
-
(err,) = cuda.cuStreamSynchronize(stream)
|
| 83 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 84 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 85 |
-
else:
|
| 86 |
-
(err,) = cudart.cudaDeviceSynchronize()
|
| 87 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 88 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 89 |
-
|
| 90 |
-
def duration(self, iterations=1):
|
| 91 |
-
err, duration = cuda.cuEventElapsedTime(self.events[0], self.events[1])
|
| 92 |
-
if err != cuda.CUresult.CUDA_SUCCESS:
|
| 93 |
-
raise RuntimeError(f"CUDA Error {str(err)}")
|
| 94 |
-
return duration / float(iterations)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
class CUDAEventProfiler:
|
| 98 |
-
def __init__(self, op: OperationBase, warmup_iterations: int=500, iterations: int=500, *args, **kwargs) -> None:
|
| 99 |
-
self.arguments = op.run(*args, **kwargs)
|
| 100 |
-
self.operation = op.operation
|
| 101 |
-
self.warmup_iterations = warmup_iterations
|
| 102 |
-
self.iterations = iterations
|
| 103 |
-
self.timer = GpuTimer()
|
| 104 |
-
|
| 105 |
-
#
|
| 106 |
-
# Cutlass Python Interface Profiler
|
| 107 |
-
#
|
| 108 |
-
|
| 109 |
-
def __call__(self):
|
| 110 |
-
for _ in range(self.warmup_iterations):
|
| 111 |
-
self.operation.run(self.arguments)
|
| 112 |
-
|
| 113 |
-
self.timer.start()
|
| 114 |
-
for _ in range(self.iterations):
|
| 115 |
-
self.operation.run(self.arguments)
|
| 116 |
-
|
| 117 |
-
self.timer.stop_and_wait()
|
| 118 |
-
runtime = self.timer.duration(self.iterations)
|
| 119 |
-
return runtime
|
| 120 |
-
|
| 121 |
-
#
|
| 122 |
-
# CUTLASS Profiler
|
| 123 |
-
#
|
| 124 |
-
|
| 125 |
-
def run_cutlass_profiler(self):
|
| 126 |
-
alpha = 1.0
|
| 127 |
-
beta = 1.0
|
| 128 |
-
|
| 129 |
-
profiler_path = CUTLASS_PATH + "/build/tools/profiler/cutlass_profiler"
|
| 130 |
-
kernel_name = self.operation.procedural_name()
|
| 131 |
-
verification_providers = "device"
|
| 132 |
-
provider = "cutlass"
|
| 133 |
-
problem_size = self.arguments.problem_size
|
| 134 |
-
|
| 135 |
-
if "cutlass3x" in kernel_name:
|
| 136 |
-
# cutlass3x generator only have column-major output
|
| 137 |
-
layout_name = self.operation.layout_name_3x()
|
| 138 |
-
if layout_name[-1] == "t":
|
| 139 |
-
new_layout_name = "".join(["n" for l in layout_name if l == "t" or "t"])
|
| 140 |
-
problem_size = GemmCoord(problem_size.n, problem_size.m, problem_size.k)
|
| 141 |
-
kernel_name = kernel_name.replace(layout_name, new_layout_name)
|
| 142 |
-
|
| 143 |
-
batch_count = self.arguments.batch_count
|
| 144 |
-
|
| 145 |
-
cmd = f"{profiler_path} --kernels={kernel_name} --verification-providers={verification_providers} " \
|
| 146 |
-
f"--providers={provider} --m={problem_size.m()} --n={problem_size.n()} --k={problem_size.k()} " \
|
| 147 |
-
f"--batch_count={batch_count} --alpha={alpha} --beta={beta} "\
|
| 148 |
-
f"--warmup-iterations={self.warmup_iterations} --profiling-iterations={self.iterations}"
|
| 149 |
-
|
| 150 |
-
result = subprocess.getoutput(cmd)
|
| 151 |
-
|
| 152 |
-
m = re.search(r"Runtime:\s+(?P<runtime>\d+.\d+)", result)
|
| 153 |
-
runtime = float(m.group("runtime"))
|
| 154 |
-
|
| 155 |
-
m = re.search(r"Bytes:\s+(?P<bytes>\d+)", result)
|
| 156 |
-
bytes = int(m.group("bytes"))
|
| 157 |
-
|
| 158 |
-
m = re.search(r"FLOPs:\s+(?P<flops>\d+)", result)
|
| 159 |
-
flops = int(m.group("flops"))
|
| 160 |
-
|
| 161 |
-
# check if the problem size matches
|
| 162 |
-
assert bytes == self.bytes(problem_size, batch_count, beta)
|
| 163 |
-
assert flops == self.flops(problem_size, batch_count, beta)
|
| 164 |
-
|
| 165 |
-
return runtime
|
| 166 |
-
|
| 167 |
-
def bytes(self, problem_size, batch_count=1, beta=0.0):
|
| 168 |
-
m = problem_size.m()
|
| 169 |
-
n = problem_size.n()
|
| 170 |
-
k = problem_size.k()
|
| 171 |
-
|
| 172 |
-
bytes = (
|
| 173 |
-
(DataTypeSize[self.operation.A.element] * m // 8) * k
|
| 174 |
-
+ (DataTypeSize[self.operation.B.element] * n // 8) * k
|
| 175 |
-
+ (DataTypeSize[self.operation.C.element] * m // 8) * n
|
| 176 |
-
)
|
| 177 |
-
|
| 178 |
-
if beta != 0:
|
| 179 |
-
bytes += (DataTypeSize[self.operation.C.element] * m // 8) * n
|
| 180 |
-
|
| 181 |
-
bytes *= batch_count
|
| 182 |
-
|
| 183 |
-
return bytes
|
| 184 |
-
|
| 185 |
-
def flops(self, problem_size, batch_count=1, beta=0.0):
|
| 186 |
-
m = problem_size.m()
|
| 187 |
-
n = problem_size.n()
|
| 188 |
-
k = problem_size.k()
|
| 189 |
-
|
| 190 |
-
flops_ = (m * n * k) * 2 * batch_count
|
| 191 |
-
|
| 192 |
-
if beta != 0:
|
| 193 |
-
flops_ += m * n * batch_count * 2
|
| 194 |
-
|
| 195 |
-
return flops_
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/__init__.py
DELETED
|
@@ -1,63 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
import os
|
| 34 |
-
import sys
|
| 35 |
-
|
| 36 |
-
from . import conv2d_operation
|
| 37 |
-
from . import conv3d_operation
|
| 38 |
-
from . import emit_kernel_listing
|
| 39 |
-
from . import gemm_operation
|
| 40 |
-
|
| 41 |
-
if '-m' not in sys.argv:
|
| 42 |
-
# Do not import generator when running python -m cutlass_library.generator to
|
| 43 |
-
# avoid double-import warnings
|
| 44 |
-
from . import generator
|
| 45 |
-
|
| 46 |
-
from . import library
|
| 47 |
-
from . import manifest
|
| 48 |
-
from . import rank_2k_operation
|
| 49 |
-
from . import rank_k_operation
|
| 50 |
-
from . import symm_operation
|
| 51 |
-
from . import trmm_operation
|
| 52 |
-
# Make enum types from library.py accessible via cutlass_library.*
|
| 53 |
-
from .library import *
|
| 54 |
-
|
| 55 |
-
# Set up `source` to point to the path containing the CUTLASS source.
|
| 56 |
-
# Check first if the path contains a `source` subdirectory -- this will
|
| 57 |
-
# be the case when the package has been installed via pip. Otherwise,
|
| 58 |
-
# default to the root of CUTLASS.
|
| 59 |
-
install_source_path = os.path.join(__path__[0], 'source')
|
| 60 |
-
if os.path.isdir(install_source_path):
|
| 61 |
-
source_path = install_source_path
|
| 62 |
-
else:
|
| 63 |
-
source_path = os.path.join(__path__[0], '../..')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv2d_operation.py
DELETED
|
@@ -1,621 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utilities for emitting Conv2d kernels
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import enum
|
| 38 |
-
import logging
|
| 39 |
-
import os.path
|
| 40 |
-
import shutil
|
| 41 |
-
from string import Template
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
import builtins
|
| 45 |
-
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
-
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
-
from cutlass_library.library import *
|
| 48 |
-
from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 49 |
-
except ImportError:
|
| 50 |
-
from library import *
|
| 51 |
-
from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 52 |
-
|
| 53 |
-
_LOGGER = logging.getLogger(__name__)
|
| 54 |
-
|
| 55 |
-
###################################################################################################
|
| 56 |
-
|
| 57 |
-
#
|
| 58 |
-
class Conv2dOperation:
|
| 59 |
-
#
|
| 60 |
-
def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
|
| 61 |
-
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity1, \
|
| 62 |
-
group_mode = GroupMode.NoneGroup):
|
| 63 |
-
|
| 64 |
-
self.operation_kind = OperationKind.Conv2d
|
| 65 |
-
self.arch = arch
|
| 66 |
-
self.tile_description = tile_description
|
| 67 |
-
self.conv_kind = conv_kind
|
| 68 |
-
self.A = A
|
| 69 |
-
self.B = B
|
| 70 |
-
self.C = C
|
| 71 |
-
self.element_epilogue = element_epilogue
|
| 72 |
-
self.epilogue_functor = epilogue_functor
|
| 73 |
-
self.iterator_algorithm = iterator_algorithm
|
| 74 |
-
self.stride_support = stride_support
|
| 75 |
-
self.swizzling_functor = swizzling_functor
|
| 76 |
-
self.group_mode = group_mode
|
| 77 |
-
#
|
| 78 |
-
def is_complex(self):
|
| 79 |
-
complex_operators = [
|
| 80 |
-
MathOperation.multiply_add_complex,
|
| 81 |
-
MathOperation.multiply_add_complex_gaussian
|
| 82 |
-
]
|
| 83 |
-
return self.tile_description.math_instruction.math_operation in complex_operators
|
| 84 |
-
|
| 85 |
-
#
|
| 86 |
-
def is_mixed_input(self):
|
| 87 |
-
return self.A.element != self.B.element
|
| 88 |
-
|
| 89 |
-
#
|
| 90 |
-
def accumulator_type(self):
|
| 91 |
-
accum = self.tile_description.math_instruction.element_accumulator
|
| 92 |
-
|
| 93 |
-
if self.is_complex():
|
| 94 |
-
return get_complex_from_real(accum)
|
| 95 |
-
|
| 96 |
-
return accum
|
| 97 |
-
|
| 98 |
-
#
|
| 99 |
-
def core_name(self):
|
| 100 |
-
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 101 |
-
|
| 102 |
-
intermediate_type = ''
|
| 103 |
-
|
| 104 |
-
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
|
| 105 |
-
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 106 |
-
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 107 |
-
self.tile_description.math_instruction.element_a != self.accumulator_type():
|
| 108 |
-
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 109 |
-
else:
|
| 110 |
-
inst_shape = ''
|
| 111 |
-
|
| 112 |
-
return "%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \
|
| 113 |
-
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
|
| 114 |
-
|
| 115 |
-
#
|
| 116 |
-
def extended_name(self):
|
| 117 |
-
''' Append data types if they differ from compute type. '''
|
| 118 |
-
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 119 |
-
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 120 |
-
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 121 |
-
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 122 |
-
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 123 |
-
extended_name = "${core_name}_${element_a}"
|
| 124 |
-
else:
|
| 125 |
-
extended_name = "${core_name}"
|
| 126 |
-
|
| 127 |
-
extended_name = SubstituteTemplate(extended_name, {
|
| 128 |
-
'element_a': DataTypeNames[self.A.element],
|
| 129 |
-
'element_c': DataTypeNames[self.C.element],
|
| 130 |
-
'core_name': self.core_name()
|
| 131 |
-
})
|
| 132 |
-
|
| 133 |
-
return extended_name
|
| 134 |
-
|
| 135 |
-
#
|
| 136 |
-
def layout_name(self):
|
| 137 |
-
return "%s" % (ShortLayoutTypeNames[self.A.layout])
|
| 138 |
-
|
| 139 |
-
#
|
| 140 |
-
def configuration_name(self):
|
| 141 |
-
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 142 |
-
|
| 143 |
-
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 144 |
-
|
| 145 |
-
threadblock = self.tile_description.procedural_name()
|
| 146 |
-
|
| 147 |
-
# grouped conv
|
| 148 |
-
if self.group_mode != GroupMode.NoneGroup:
|
| 149 |
-
group_conv_name = f"{GroupModeNames[self.group_mode]}_"
|
| 150 |
-
else:
|
| 151 |
-
group_conv_name = ""
|
| 152 |
-
|
| 153 |
-
if self.stride_support == StrideSupport.Unity and self.conv_kind == ConvKind.Dgrad:
|
| 154 |
-
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_unity_stride_${group_conv_name}align${alignment}"
|
| 155 |
-
else:
|
| 156 |
-
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_${layout}_${group_conv_name}align${alignment}"
|
| 157 |
-
|
| 158 |
-
return SubstituteTemplate(
|
| 159 |
-
configuration_name,
|
| 160 |
-
{
|
| 161 |
-
'opcode_class': opcode_class_name,
|
| 162 |
-
'extended_name': self.extended_name(),
|
| 163 |
-
'threadblock': threadblock,
|
| 164 |
-
'layout': self.layout_name(),
|
| 165 |
-
'alignment': "%d" % self.A.alignment,
|
| 166 |
-
'group_conv_name': group_conv_name
|
| 167 |
-
}
|
| 168 |
-
)
|
| 169 |
-
|
| 170 |
-
#
|
| 171 |
-
def procedural_name(self):
|
| 172 |
-
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 173 |
-
return self.configuration_name()
|
| 174 |
-
|
| 175 |
-
###################################################################################################
|
| 176 |
-
#
|
| 177 |
-
# Emits single instances of a CUTLASS device-wide operator
|
| 178 |
-
#
|
| 179 |
-
###################################################################################################
|
| 180 |
-
|
| 181 |
-
class EmitConv2dInstance:
|
| 182 |
-
def __init__(self):
|
| 183 |
-
# Emitter for CUTLASS 3 convolution operations
|
| 184 |
-
self.conv3x_emitter = EmitConv3xInstance()
|
| 185 |
-
self.template = """
|
| 186 |
-
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 187 |
-
using ${operation_name}_base =
|
| 188 |
-
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
|
| 189 |
-
${element_a},
|
| 190 |
-
${layout_a},
|
| 191 |
-
${element_b},
|
| 192 |
-
${layout_b},
|
| 193 |
-
${element_c},
|
| 194 |
-
${layout_c},
|
| 195 |
-
${element_accumulator},
|
| 196 |
-
${opcode_class},
|
| 197 |
-
${arch},
|
| 198 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 199 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
| 200 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 201 |
-
${epilogue_functor}<
|
| 202 |
-
${element_c},
|
| 203 |
-
${epilogue_vector_length},
|
| 204 |
-
${element_accumulator},
|
| 205 |
-
${element_epilogue}
|
| 206 |
-
>,
|
| 207 |
-
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
| 208 |
-
${stages},
|
| 209 |
-
${math_operator},
|
| 210 |
-
${iterator_algorithm},
|
| 211 |
-
${stride_support},
|
| 212 |
-
${align_a},
|
| 213 |
-
${align_b}
|
| 214 |
-
>::Kernel;
|
| 215 |
-
"""
|
| 216 |
-
self.template_group_conv = """
|
| 217 |
-
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 218 |
-
using ${operation_name}_base =
|
| 219 |
-
typename cutlass::conv::kernel::DefaultConv2dGroup${conv_kind_name}<
|
| 220 |
-
${element_a},
|
| 221 |
-
${layout_a},
|
| 222 |
-
${element_b},
|
| 223 |
-
${layout_b},
|
| 224 |
-
${element_c},
|
| 225 |
-
${layout_c},
|
| 226 |
-
${element_accumulator},
|
| 227 |
-
${opcode_class},
|
| 228 |
-
${arch},
|
| 229 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 230 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
| 231 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 232 |
-
${epilogue_functor}<
|
| 233 |
-
${element_c},
|
| 234 |
-
${epilogue_vector_length},
|
| 235 |
-
${element_accumulator},
|
| 236 |
-
${element_epilogue}
|
| 237 |
-
>,
|
| 238 |
-
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
| 239 |
-
${stages},
|
| 240 |
-
${math_operator},
|
| 241 |
-
${group_mode},
|
| 242 |
-
${iterator_algorithm},
|
| 243 |
-
${stride_support},
|
| 244 |
-
${align_a},
|
| 245 |
-
${align_b}
|
| 246 |
-
>::Kernel;
|
| 247 |
-
"""
|
| 248 |
-
self.template_depthwise_direct_conv = """
|
| 249 |
-
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 250 |
-
using ${operation_name}_base =
|
| 251 |
-
typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConv${conv_kind_name}<
|
| 252 |
-
${element_a},
|
| 253 |
-
${layout_a},
|
| 254 |
-
${element_b},
|
| 255 |
-
${layout_b},
|
| 256 |
-
${element_c},
|
| 257 |
-
${layout_c},
|
| 258 |
-
${element_accumulator},
|
| 259 |
-
${opcode_class},
|
| 260 |
-
${arch},
|
| 261 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 262 |
-
cutlass::conv::TensorNHWCShape<${threadblock_output_shape_n}, ${threadblock_output_shape_p}, ${threadblock_output_shape_q}, ${groups_per_cta}>,
|
| 263 |
-
cutlass::MatrixShape<${filter_shape_r}, ${filter_shape_s}>,
|
| 264 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
|
| 265 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 266 |
-
${epilogue_functor}<
|
| 267 |
-
${element_c},
|
| 268 |
-
${epilogue_vector_length},
|
| 269 |
-
${element_accumulator},
|
| 270 |
-
${element_epilogue},
|
| 271 |
-
cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling
|
| 272 |
-
>,
|
| 273 |
-
|
| 274 |
-
cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<
|
| 275 |
-
1,
|
| 276 |
-
${threadblock_output_shape_n},
|
| 277 |
-
${threadblock_output_shape_p},
|
| 278 |
-
${threadblock_output_shape_q}>,
|
| 279 |
-
${stages},
|
| 280 |
-
${math_operator},
|
| 281 |
-
${iterator_algorithm},
|
| 282 |
-
${stride_support},
|
| 283 |
-
cutlass::MatrixShape<${stride_r}, ${stride_s}>,
|
| 284 |
-
cutlass::MatrixShape<${dilation_r}, ${dilation_s}>
|
| 285 |
-
>::Kernel;
|
| 286 |
-
"""
|
| 287 |
-
|
| 288 |
-
def arch_number_to_type(self, arch: int):
|
| 289 |
-
return f"cutlass::arch::Sm{arch}"
|
| 290 |
-
|
| 291 |
-
def emit(self, operation):
|
| 292 |
-
_LOGGER.debug("*** EmitConv2dInstance::emit")
|
| 293 |
-
_LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
|
| 294 |
-
|
| 295 |
-
if hasattr(operation, 'is_3x') and operation.is_3x:
|
| 296 |
-
_LOGGER.debug("*** CUTLASS 3 operation")
|
| 297 |
-
return self.conv3x_emitter.emit(operation)
|
| 298 |
-
|
| 299 |
-
_LOGGER.debug("*** CUTLASS 2 operation")
|
| 300 |
-
|
| 301 |
-
warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
|
| 302 |
-
|
| 303 |
-
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 304 |
-
|
| 305 |
-
values = {
|
| 306 |
-
'operation_name': operation.procedural_name(),
|
| 307 |
-
'conv_kind': ConvKindTag[operation.conv_kind],
|
| 308 |
-
'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
|
| 309 |
-
'element_a': DataTypeTag[operation.A.element],
|
| 310 |
-
'layout_a': LayoutTag[operation.A.layout],
|
| 311 |
-
'element_b': DataTypeTag[operation.B.element],
|
| 312 |
-
'layout_b': LayoutTag[operation.B.layout],
|
| 313 |
-
'element_c': DataTypeTag[operation.C.element],
|
| 314 |
-
'layout_c': LayoutTag[operation.C.layout],
|
| 315 |
-
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 316 |
-
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 317 |
-
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 318 |
-
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 319 |
-
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 320 |
-
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 321 |
-
'warp_shape_m': str(warp_shape[0]),
|
| 322 |
-
'warp_shape_n': str(warp_shape[1]),
|
| 323 |
-
'warp_shape_k': str(warp_shape[2]),
|
| 324 |
-
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 325 |
-
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 326 |
-
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 327 |
-
'epilogue_vector_length': str(epilogue_vector_length),
|
| 328 |
-
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 329 |
-
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 330 |
-
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 331 |
-
'stages': str(operation.tile_description.stages),
|
| 332 |
-
'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
|
| 333 |
-
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
| 334 |
-
'stride_support': StrideSupportTag[operation.stride_support],
|
| 335 |
-
'math_operator': 'cutlass::arch::OpMultiplyAddComplex' if operation.is_complex() else \
|
| 336 |
-
MathOperationTag[operation.tile_description.math_instruction.math_operation],
|
| 337 |
-
'align_a': str(operation.A.alignment),
|
| 338 |
-
'align_b': str(operation.B.alignment),
|
| 339 |
-
}
|
| 340 |
-
|
| 341 |
-
if operation.group_mode == GroupMode.NoneGroup:
|
| 342 |
-
_LOGGER.debug("*** group_mode=NoneGroup")
|
| 343 |
-
return SubstituteTemplate(self.template, values)
|
| 344 |
-
|
| 345 |
-
elif operation.group_mode == GroupMode.Depthwise:
|
| 346 |
-
_LOGGER.debug("*** group_mode=Depthwise")
|
| 347 |
-
values['group_mode'] = GroupModeTag[operation.group_mode]
|
| 348 |
-
# Setup other template params
|
| 349 |
-
values['threadblock_output_shape_n'] = str(operation.tile_description.threadblock_output_shape[0])
|
| 350 |
-
values['threadblock_output_shape_p'] = str(operation.tile_description.threadblock_output_shape[1])
|
| 351 |
-
values['threadblock_output_shape_q'] = str(operation.tile_description.threadblock_output_shape[2])
|
| 352 |
-
|
| 353 |
-
values['groups_per_cta'] = str(operation.tile_description.threadblock_output_shape[3])
|
| 354 |
-
|
| 355 |
-
values['filter_shape_r'] = str(operation.tile_description.filter_shape[0])
|
| 356 |
-
values['filter_shape_s'] = str(operation.tile_description.filter_shape[1])
|
| 357 |
-
|
| 358 |
-
values['stride_r'] = str(operation.tile_description.stride[0])
|
| 359 |
-
values['stride_s'] = str(operation.tile_description.stride[1])
|
| 360 |
-
|
| 361 |
-
values['dilation_r'] = str(operation.tile_description.dilation[0])
|
| 362 |
-
values['dilation_s'] = str(operation.tile_description.dilation[1])
|
| 363 |
-
|
| 364 |
-
return SubstituteTemplate(self.template_depthwise_direct_conv, values)
|
| 365 |
-
|
| 366 |
-
else:
|
| 367 |
-
_LOGGER.debug("*** group_mode=" + GroupModeTag[operation.group_mode])
|
| 368 |
-
values['group_mode'] = GroupModeTag[operation.group_mode]
|
| 369 |
-
return SubstituteTemplate(self.template_group_conv, values)
|
| 370 |
-
|
| 371 |
-
###################################################################################################
|
| 372 |
-
#
|
| 373 |
-
# Generator functions for all layouts
|
| 374 |
-
#
|
| 375 |
-
###################################################################################################
|
| 376 |
-
|
| 377 |
-
#
|
| 378 |
-
def GenerateConv2dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
|
| 379 |
-
_LOGGER.debug("*** GenerateConv2dTensorOp")
|
| 380 |
-
|
| 381 |
-
for tile in tile_descriptions:
|
| 382 |
-
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
| 383 |
-
|
| 384 |
-
if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
|
| 385 |
-
|
| 386 |
-
#
|
| 387 |
-
output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
|
| 388 |
-
if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
|
| 389 |
-
else [tile.math_instruction.element_accumulator,]
|
| 390 |
-
|
| 391 |
-
for output_type in output_types:
|
| 392 |
-
A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
|
| 393 |
-
B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
|
| 394 |
-
C = TensorDescription(output_type, LayoutType.TensorNHWC, max(1, int(align / DataTypeSize[output_type])))
|
| 395 |
-
|
| 396 |
-
manifest.append(Conv2dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
|
| 397 |
-
|
| 398 |
-
class EmitConv2dIncludes:
|
| 399 |
-
'''Emit includes that are specific to the operation.'''
|
| 400 |
-
|
| 401 |
-
def __init__(self):
|
| 402 |
-
self.includes = ['conv2d_operation.h']
|
| 403 |
-
self.emitter_3x = EmitConv3xIncludes()
|
| 404 |
-
|
| 405 |
-
def operation_is_3x(self, operation) -> bool:
|
| 406 |
-
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 407 |
-
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 408 |
-
|
| 409 |
-
def emit(self, operation) -> str:
|
| 410 |
-
if self.operation_is_3x(operation):
|
| 411 |
-
return self.emitter_3x.emit(operation)
|
| 412 |
-
|
| 413 |
-
return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
|
| 414 |
-
"\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
|
| 415 |
-
|
| 416 |
-
###################################################################################################
|
| 417 |
-
#
|
| 418 |
-
# Emitters functions for all targets
|
| 419 |
-
#
|
| 420 |
-
###################################################################################################
|
| 421 |
-
|
| 422 |
-
class EmitConv2dConfigurationLibrary:
|
| 423 |
-
def __init__(self, operation_path, configuration_name):
|
| 424 |
-
self.configuration_name = configuration_name
|
| 425 |
-
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
|
| 426 |
-
|
| 427 |
-
self.instance_emitter = EmitConv2dInstance()
|
| 428 |
-
self.includes_emitter = EmitConv2dIncludes()
|
| 429 |
-
|
| 430 |
-
self.header_template = """
|
| 431 |
-
/*
|
| 432 |
-
Generated by conv2d_operation.py - Do not edit.
|
| 433 |
-
*/
|
| 434 |
-
|
| 435 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 436 |
-
|
| 437 |
-
#include "cutlass/cutlass.h"
|
| 438 |
-
#include "cutlass/library/library.h"
|
| 439 |
-
#include "cutlass/library/manifest.h"
|
| 440 |
-
|
| 441 |
-
#include "library_internal.h"
|
| 442 |
-
"""
|
| 443 |
-
|
| 444 |
-
self.instance_template = """
|
| 445 |
-
${stub_begin}
|
| 446 |
-
${operation_instance}
|
| 447 |
-
// Derived class
|
| 448 |
-
struct ${operation_name} :
|
| 449 |
-
public ${operation_name}_base { };
|
| 450 |
-
${stub_end}
|
| 451 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 452 |
-
|
| 453 |
-
"""
|
| 454 |
-
|
| 455 |
-
self.configuration_header = """
|
| 456 |
-
|
| 457 |
-
namespace cutlass {
|
| 458 |
-
namespace library {
|
| 459 |
-
|
| 460 |
-
// Initialize all instances
|
| 461 |
-
void initialize_${configuration_name}(Manifest &manifest) {
|
| 462 |
-
"""
|
| 463 |
-
|
| 464 |
-
self.configuration_instance = """${stub_begin}
|
| 465 |
-
using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
|
| 466 |
-
${operation_name}>;
|
| 467 |
-
|
| 468 |
-
manifest.append(new cutlass::library::${operation_wrapper}<
|
| 469 |
-
Operation_${operation_name}
|
| 470 |
-
>(
|
| 471 |
-
"${operation_name}"
|
| 472 |
-
));
|
| 473 |
-
${stub_end}
|
| 474 |
-
"""
|
| 475 |
-
|
| 476 |
-
self.configuration_epilogue = "}\n"
|
| 477 |
-
|
| 478 |
-
self.epilogue_template = """
|
| 479 |
-
|
| 480 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 481 |
-
|
| 482 |
-
} // namespace library
|
| 483 |
-
} // namespace cutlass
|
| 484 |
-
|
| 485 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 486 |
-
|
| 487 |
-
"""
|
| 488 |
-
|
| 489 |
-
def operation_is_3x(self, operation):
|
| 490 |
-
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 491 |
-
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 492 |
-
|
| 493 |
-
def __enter__(self):
|
| 494 |
-
"""
|
| 495 |
-
Open the configuration_file, and write the "header" C++ code to it.
|
| 496 |
-
|
| 497 |
-
The "header" consists of a comment (that this is generated code,
|
| 498 |
-
so it should not be edited), and includes that are common
|
| 499 |
-
to all kinds of kernels.
|
| 500 |
-
"""
|
| 501 |
-
_LOGGER.debug('*** EmitConv2dConfigurationLibrary::__enter__')
|
| 502 |
-
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 503 |
-
str(self.configuration_path))
|
| 504 |
-
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 505 |
-
self.configuration_file = open(self.configuration_path, "w")
|
| 506 |
-
|
| 507 |
-
self.configuration_file.write(SubstituteTemplate(self.header_template, {
|
| 508 |
-
'configuration_name': self.configuration_name
|
| 509 |
-
}))
|
| 510 |
-
self.operations = []
|
| 511 |
-
return self
|
| 512 |
-
|
| 513 |
-
def emit(self, operation):
|
| 514 |
-
"""
|
| 515 |
-
Write three pieces of C++ code to the configuration_file
|
| 516 |
-
(that was opened by the __enter__ method above):
|
| 517 |
-
|
| 518 |
-
1. the header includes that are specific to the operation
|
| 519 |
-
(CUTLASS 2 vs. CUTLASS 3);
|
| 520 |
-
|
| 521 |
-
2. the "operation instance" (a "using" declaration ending in "_base"); and
|
| 522 |
-
|
| 523 |
-
3. the "operation name" (declaration and definition of a derived class
|
| 524 |
-
of the above operation instance).
|
| 525 |
-
|
| 526 |
-
The "using" declaration turns a C++ class name, possibly namespace-qualified,
|
| 527 |
-
possibly also with angle brackets, into a C-style, easily demangled identifier.
|
| 528 |
-
"""
|
| 529 |
-
_LOGGER.debug('*** EmitConv2dConfigurationLibrary::emit')
|
| 530 |
-
_LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
|
| 531 |
-
self.operations.append(operation)
|
| 532 |
-
|
| 533 |
-
self.configuration_file.write(self.includes_emitter.emit(operation))
|
| 534 |
-
|
| 535 |
-
stub_begin = ''
|
| 536 |
-
stub_end = ''
|
| 537 |
-
# It can be useful to stub (comment) out instantiations for testing.
|
| 538 |
-
# In this case, one need only set is_stub to True.
|
| 539 |
-
is_stub = False
|
| 540 |
-
if is_stub:
|
| 541 |
-
stub_begin = "// STUB for now\n#if 0"
|
| 542 |
-
stub_end = '#endif // 0'
|
| 543 |
-
|
| 544 |
-
self.configuration_file.write(Template(self.instance_template).substitute({
|
| 545 |
-
'configuration_name': self.configuration_name,
|
| 546 |
-
'operation_name': operation.procedural_name(),
|
| 547 |
-
'operation_instance': self.instance_emitter.emit(operation),
|
| 548 |
-
'stub_begin': stub_begin,
|
| 549 |
-
'stub_end': stub_end
|
| 550 |
-
}))
|
| 551 |
-
|
| 552 |
-
def __exit__(self, exception_type, exception_value, traceback):
|
| 553 |
-
"""
|
| 554 |
-
Write the rest of the C++ code to the configuration_file, and close the file.
|
| 555 |
-
|
| 556 |
-
The "rest of the C++ code" has the following components.
|
| 557 |
-
|
| 558 |
-
1. Configuration header: Open the namespace(s), and open the definition
|
| 559 |
-
of the "initialize_${configuration_name}" registration function
|
| 560 |
-
that registers the operation with the Manifest.
|
| 561 |
-
("Registration" helps turn C++ compile-time polymorphism
|
| 562 |
-
(via template parameters) into a run-time choice of parameters.)
|
| 563 |
-
|
| 564 |
-
2. Configuration instance: In the body of the registration function,
|
| 565 |
-
make a "using" declaration Operation_${operation_name} for the
|
| 566 |
-
operation type (which uses operation_name as its template argument).
|
| 567 |
-
Then, tell the manifest about the operation via a "manifest.append" call.
|
| 568 |
-
The argument of the call is a new instance of
|
| 569 |
-
"SomethingOperation<Operation_${operation_name}>"
|
| 570 |
-
(replace Something with a specific name).
|
| 571 |
-
|
| 572 |
-
3. Configuration epilogue: Close the definition of the registration function.
|
| 573 |
-
|
| 574 |
-
4. Epilogue template: Close the namespace(s).
|
| 575 |
-
"""
|
| 576 |
-
|
| 577 |
-
_LOGGER.debug('*** EmitConv2dConfigurationLibrary::__exit__')
|
| 578 |
-
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 579 |
-
str(self.configuration_path))
|
| 580 |
-
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 581 |
-
|
| 582 |
-
self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
|
| 583 |
-
'configuration_name': self.configuration_name
|
| 584 |
-
}))
|
| 585 |
-
|
| 586 |
-
for operation in self.operations:
|
| 587 |
-
stub_begin = ''
|
| 588 |
-
stub_end = ''
|
| 589 |
-
# It can be useful to stub (comment) out instantiations for testing.
|
| 590 |
-
# In this case, one need only set is_stub to True.
|
| 591 |
-
is_stub = False
|
| 592 |
-
if is_stub:
|
| 593 |
-
stub_begin = "// STUB for now\n#if 0"
|
| 594 |
-
stub_end = "#endif // 0"
|
| 595 |
-
|
| 596 |
-
if operation.group_mode == GroupMode.Depthwise:
|
| 597 |
-
kernel_name = 'DirectConvolution'
|
| 598 |
-
operation_wrapper = 'DirectConv2dOperation'
|
| 599 |
-
else:
|
| 600 |
-
kernel_name = 'ImplicitGemmConvolution'
|
| 601 |
-
operation_wrapper = 'Conv2dOperation'
|
| 602 |
-
if self.operation_is_3x(operation):
|
| 603 |
-
kernel_name = 'ConvUniversalAdapter'
|
| 604 |
-
operation_wrapper = 'ConvOperation3x'
|
| 605 |
-
|
| 606 |
-
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
|
| 607 |
-
'configuration_name': self.configuration_name,
|
| 608 |
-
'operation_name': operation.procedural_name(),
|
| 609 |
-
'kernel_name': kernel_name,
|
| 610 |
-
'operation_wrapper': operation_wrapper,
|
| 611 |
-
'stub_begin': stub_begin,
|
| 612 |
-
'stub_end': stub_end
|
| 613 |
-
}))
|
| 614 |
-
|
| 615 |
-
self.configuration_file.write(self.configuration_epilogue)
|
| 616 |
-
self.configuration_file.write(self.epilogue_template)
|
| 617 |
-
self.configuration_file.close()
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
###################################################################################################
|
| 621 |
-
###################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3d_operation.py
DELETED
|
@@ -1,482 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utilities for emitting Conv3d kernels
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import enum
|
| 38 |
-
import logging
|
| 39 |
-
import os.path
|
| 40 |
-
import shutil
|
| 41 |
-
from string import Template
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
import builtins
|
| 45 |
-
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
-
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
-
from cutlass_library.library import *
|
| 48 |
-
from cutlass_library.conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 49 |
-
except ImportError:
|
| 50 |
-
from library import *
|
| 51 |
-
from conv3x_emitter import EmitConv3xInstance, EmitConv3xIncludes
|
| 52 |
-
|
| 53 |
-
_LOGGER = logging.getLogger(__name__)
|
| 54 |
-
|
| 55 |
-
###################################################################################################
|
| 56 |
-
|
| 57 |
-
#
|
| 58 |
-
class Conv3dOperation:
|
| 59 |
-
#
|
| 60 |
-
def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, C, element_epilogue, \
|
| 61 |
-
stride_support, epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4):
|
| 62 |
-
|
| 63 |
-
self.operation_kind = OperationKind.Conv3d
|
| 64 |
-
self.arch = arch
|
| 65 |
-
self.tile_description = tile_description
|
| 66 |
-
self.conv_kind = conv_kind
|
| 67 |
-
self.A = A
|
| 68 |
-
self.B = B
|
| 69 |
-
self.C = C
|
| 70 |
-
self.element_epilogue = element_epilogue
|
| 71 |
-
self.epilogue_functor = epilogue_functor
|
| 72 |
-
self.iterator_algorithm = iterator_algorithm
|
| 73 |
-
self.stride_support = stride_support
|
| 74 |
-
self.swizzling_functor = swizzling_functor
|
| 75 |
-
|
| 76 |
-
#
|
| 77 |
-
def is_mixed_input(self):
|
| 78 |
-
return self.A.element != self.B.element
|
| 79 |
-
|
| 80 |
-
#
|
| 81 |
-
def core_name(self):
|
| 82 |
-
''' The basic operation kind is prefixed with a letter indicating the accumulation type. '''
|
| 83 |
-
|
| 84 |
-
intermediate_type = ''
|
| 85 |
-
|
| 86 |
-
if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp:
|
| 87 |
-
inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape)
|
| 88 |
-
if self.tile_description.math_instruction.element_a != self.A.element and \
|
| 89 |
-
self.tile_description.math_instruction.element_a != self.tile_description.math_instruction.element_accumulator:
|
| 90 |
-
intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a]
|
| 91 |
-
else:
|
| 92 |
-
inst_shape = ''
|
| 93 |
-
|
| 94 |
-
return "%s%s%s%s3d_%s" % (ShortDataTypeNames[self.tile_description.math_instruction.element_accumulator], \
|
| 95 |
-
inst_shape, intermediate_type, ConvKindNames[self.conv_kind], IteratorAlgorithmNames[self.iterator_algorithm])
|
| 96 |
-
|
| 97 |
-
#
|
| 98 |
-
def extended_name(self):
|
| 99 |
-
''' Append data types if they differ from compute type. '''
|
| 100 |
-
if self.C.element != self.tile_description.math_instruction.element_accumulator and \
|
| 101 |
-
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 102 |
-
extended_name = "${element_c}_${core_name}_${element_a}"
|
| 103 |
-
elif self.C.element == self.tile_description.math_instruction.element_accumulator and \
|
| 104 |
-
self.A.element != self.tile_description.math_instruction.element_accumulator:
|
| 105 |
-
extended_name = "${core_name}_${element_a}"
|
| 106 |
-
else:
|
| 107 |
-
extended_name = "${core_name}"
|
| 108 |
-
|
| 109 |
-
extended_name = SubstituteTemplate(extended_name, {
|
| 110 |
-
'element_a': DataTypeNames[self.A.element],
|
| 111 |
-
'element_c': DataTypeNames[self.C.element],
|
| 112 |
-
'core_name': self.core_name()
|
| 113 |
-
})
|
| 114 |
-
|
| 115 |
-
return extended_name
|
| 116 |
-
|
| 117 |
-
#
|
| 118 |
-
def configuration_name(self):
|
| 119 |
-
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 120 |
-
|
| 121 |
-
opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class]
|
| 122 |
-
|
| 123 |
-
threadblock = "%dx%d_%dx%d" % (
|
| 124 |
-
self.tile_description.threadblock_shape[0],
|
| 125 |
-
self.tile_description.threadblock_shape[1],
|
| 126 |
-
self.tile_description.threadblock_shape[2],
|
| 127 |
-
self.tile_description.stages
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
if self.stride_support == StrideSupport.Unity:
|
| 131 |
-
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}_unity_stride"
|
| 132 |
-
else:
|
| 133 |
-
configuration_name = "cutlass_${opcode_class}_${extended_name}_${threadblock}"
|
| 134 |
-
|
| 135 |
-
return SubstituteTemplate(
|
| 136 |
-
configuration_name,
|
| 137 |
-
{
|
| 138 |
-
'opcode_class': opcode_class_name,
|
| 139 |
-
'extended_name': self.extended_name(),
|
| 140 |
-
'threadblock': threadblock,
|
| 141 |
-
}
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
#
|
| 145 |
-
def procedural_name(self):
|
| 146 |
-
''' The full procedural name indicates architecture, extended name, tile size, and layout. '''
|
| 147 |
-
return self.configuration_name()
|
| 148 |
-
|
| 149 |
-
###################################################################################################
|
| 150 |
-
#
|
| 151 |
-
# Emits single instances of a CUTLASS device-wide operator
|
| 152 |
-
#
|
| 153 |
-
###################################################################################################
|
| 154 |
-
|
| 155 |
-
class EmitConv3dInstance:
|
| 156 |
-
def __init__(self):
|
| 157 |
-
# Emitter for CUTLASS 3 convolution operations
|
| 158 |
-
self.conv3x_emitter = EmitConv3xInstance()
|
| 159 |
-
self.template = """
|
| 160 |
-
// Conv3d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
|
| 161 |
-
using ${operation_name}_base =
|
| 162 |
-
typename cutlass::conv::kernel::DefaultConv3d${conv_kind_name}<
|
| 163 |
-
${element_a},
|
| 164 |
-
cutlass::layout::TensorNDHWC,
|
| 165 |
-
${element_b},
|
| 166 |
-
cutlass::layout::TensorNDHWC,
|
| 167 |
-
${element_c},
|
| 168 |
-
cutlass::layout::TensorNDHWC,
|
| 169 |
-
${element_accumulator},
|
| 170 |
-
${opcode_class},
|
| 171 |
-
${arch},
|
| 172 |
-
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
|
| 173 |
-
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
|
| 174 |
-
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
|
| 175 |
-
${epilogue_functor}<
|
| 176 |
-
${element_c},
|
| 177 |
-
${epilogue_vector_length},
|
| 178 |
-
${element_accumulator},
|
| 179 |
-
${element_epilogue}
|
| 180 |
-
>,
|
| 181 |
-
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
|
| 182 |
-
${stages},
|
| 183 |
-
cutlass::arch::OpMultiplyAdd,
|
| 184 |
-
${iterator_algorithm},
|
| 185 |
-
${stride_support}
|
| 186 |
-
>::Kernel;
|
| 187 |
-
"""
|
| 188 |
-
|
| 189 |
-
def emit(self, operation):
|
| 190 |
-
_LOGGER.debug("*** EmitConv3dInstance::emit")
|
| 191 |
-
_LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
|
| 192 |
-
|
| 193 |
-
if hasattr(operation, 'is_3x') and operation.is_3x:
|
| 194 |
-
_LOGGER.debug("*** CUTLASS 3 operation")
|
| 195 |
-
return self.conv3x_emitter.emit(operation)
|
| 196 |
-
|
| 197 |
-
_LOGGER.debug("*** CUTLASS 2 operation")
|
| 198 |
-
|
| 199 |
-
warp_shape = [int(operation.tile_description.threadblock_shape[idx] / operation.tile_description.warp_count[idx]) for idx in range(3)]
|
| 200 |
-
|
| 201 |
-
epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element])
|
| 202 |
-
|
| 203 |
-
values = {
|
| 204 |
-
'operation_name': operation.procedural_name(),
|
| 205 |
-
'conv_kind': ConvKindTag[operation.conv_kind],
|
| 206 |
-
'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
|
| 207 |
-
'element_a': DataTypeTag[operation.A.element],
|
| 208 |
-
'layout_a': LayoutTag[operation.A.layout],
|
| 209 |
-
'element_b': DataTypeTag[operation.B.element],
|
| 210 |
-
'layout_b': LayoutTag[operation.B.layout],
|
| 211 |
-
'element_c': DataTypeTag[operation.C.element],
|
| 212 |
-
'layout_c': LayoutTag[operation.C.layout],
|
| 213 |
-
'element_accumulator': DataTypeTag[operation.tile_description.math_instruction.element_accumulator],
|
| 214 |
-
'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],
|
| 215 |
-
'arch': "cutlass::arch::Sm%d" % operation.arch,
|
| 216 |
-
'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]),
|
| 217 |
-
'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]),
|
| 218 |
-
'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]),
|
| 219 |
-
'warp_shape_m': str(warp_shape[0]),
|
| 220 |
-
'warp_shape_n': str(warp_shape[1]),
|
| 221 |
-
'warp_shape_k': str(warp_shape[2]),
|
| 222 |
-
'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]),
|
| 223 |
-
'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]),
|
| 224 |
-
'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]),
|
| 225 |
-
'epilogue_vector_length': str(epilogue_vector_length),
|
| 226 |
-
'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor],
|
| 227 |
-
'element_epilogue': str(DataTypeTag[operation.element_epilogue]),
|
| 228 |
-
'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor],
|
| 229 |
-
'stages': str(operation.tile_description.stages),
|
| 230 |
-
'iterator_algorithm': IteratorAlgorithmTag[operation.iterator_algorithm],
|
| 231 |
-
'iterator_algorithm_name': IteratorAlgorithmNames[operation.iterator_algorithm].capitalize(),
|
| 232 |
-
'stride_support': StrideSupportTag[operation.stride_support]
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
return SubstituteTemplate(self.template, values)
|
| 236 |
-
|
| 237 |
-
###################################################################################################
|
| 238 |
-
#
|
| 239 |
-
# Generator functions for all layouts
|
| 240 |
-
#
|
| 241 |
-
###################################################################################################
|
| 242 |
-
|
| 243 |
-
#
|
| 244 |
-
def GenerateConv3dTensorOp(manifest, tile_descriptions, min_cc, align = 128):
|
| 245 |
-
|
| 246 |
-
for tile in tile_descriptions:
|
| 247 |
-
for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]:
|
| 248 |
-
|
| 249 |
-
if conv_kind == ConvKind.Fprop or (tile.math_instruction.element_accumulator in [DataType.f16, DataType.f32]):
|
| 250 |
-
|
| 251 |
-
#
|
| 252 |
-
output_types = [tile.math_instruction.element_a, tile.math_instruction.element_accumulator] \
|
| 253 |
-
if DataTypeSize[tile.math_instruction.element_accumulator] == 32 \
|
| 254 |
-
else [tile.math_instruction.element_accumulator,]
|
| 255 |
-
|
| 256 |
-
for output_type in output_types:
|
| 257 |
-
A = TensorDescription(tile.math_instruction.element_a, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_a]))
|
| 258 |
-
B = TensorDescription(tile.math_instruction.element_b, LayoutType.TensorNDHWC, int(align / DataTypeSize[tile.math_instruction.element_b]))
|
| 259 |
-
C = TensorDescription(output_type, LayoutType.TensorNDHWC, max(1, int(align / DataTypeSize[output_type])))
|
| 260 |
-
|
| 261 |
-
manifest.append(Conv3dOperation(conv_kind, min_cc, tile, A, B, C, tile.math_instruction.element_accumulator))
|
| 262 |
-
|
| 263 |
-
class EmitConv3dIncludes:
|
| 264 |
-
'''Emit includes that are specific to the operation.'''
|
| 265 |
-
|
| 266 |
-
def __init__(self):
|
| 267 |
-
self.includes = ['conv3d_operation.h']
|
| 268 |
-
self.emitter_3x = EmitConv3xIncludes()
|
| 269 |
-
|
| 270 |
-
def operation_is_3x(self, operation) -> bool:
|
| 271 |
-
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 272 |
-
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 273 |
-
|
| 274 |
-
def emit(self, operation) -> str:
|
| 275 |
-
if self.operation_is_3x(operation):
|
| 276 |
-
return self.emitter_3x.emit(operation)
|
| 277 |
-
|
| 278 |
-
return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
|
| 279 |
-
"\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
|
| 280 |
-
|
| 281 |
-
###################################################################################################
|
| 282 |
-
#
|
| 283 |
-
# Emitters functions for all targets
|
| 284 |
-
#
|
| 285 |
-
###################################################################################################
|
| 286 |
-
|
| 287 |
-
class EmitConv3dConfigurationLibrary:
|
| 288 |
-
def __init__(self, operation_path, configuration_name):
|
| 289 |
-
self.configuration_name = configuration_name
|
| 290 |
-
self.configuration_path = os.path.join(operation_path, "%s.cu" % configuration_name)
|
| 291 |
-
|
| 292 |
-
self.instance_emitter = EmitConv3dInstance()
|
| 293 |
-
self.includes_emitter = EmitConv3dIncludes()
|
| 294 |
-
|
| 295 |
-
self.header_template = """
|
| 296 |
-
/*
|
| 297 |
-
Generated by conv3d_operation.py - Do not edit.
|
| 298 |
-
*/
|
| 299 |
-
|
| 300 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 301 |
-
|
| 302 |
-
#include "cutlass/cutlass.h"
|
| 303 |
-
#include "cutlass/library/library.h"
|
| 304 |
-
#include "cutlass/library/manifest.h"
|
| 305 |
-
|
| 306 |
-
#include "library_internal.h"
|
| 307 |
-
"""
|
| 308 |
-
|
| 309 |
-
self.instance_template = """
|
| 310 |
-
${stub_begin}
|
| 311 |
-
${operation_instance}
|
| 312 |
-
// Derived class
|
| 313 |
-
struct ${operation_name} :
|
| 314 |
-
public ${operation_name}_base { };
|
| 315 |
-
${stub_end}
|
| 316 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 317 |
-
|
| 318 |
-
"""
|
| 319 |
-
|
| 320 |
-
self.configuration_header = """
|
| 321 |
-
|
| 322 |
-
namespace cutlass {
|
| 323 |
-
namespace library {
|
| 324 |
-
|
| 325 |
-
// Initialize all instances
|
| 326 |
-
void initialize_${configuration_name}(Manifest &manifest) {
|
| 327 |
-
"""
|
| 328 |
-
|
| 329 |
-
self.configuration_instance = """${stub_begin}
|
| 330 |
-
using Operation_${operation_name} = cutlass::conv::device::${kernel_name}<
|
| 331 |
-
${operation_name}>;
|
| 332 |
-
|
| 333 |
-
manifest.append(new cutlass::library::${operation_wrapper}<
|
| 334 |
-
Operation_${operation_name}
|
| 335 |
-
>(
|
| 336 |
-
"${operation_name}"
|
| 337 |
-
));
|
| 338 |
-
${stub_end}
|
| 339 |
-
"""
|
| 340 |
-
|
| 341 |
-
self.configuration_epilogue = "}\n"
|
| 342 |
-
|
| 343 |
-
self.epilogue_template = """
|
| 344 |
-
|
| 345 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 346 |
-
|
| 347 |
-
} // namespace library
|
| 348 |
-
} // namespace cutlass
|
| 349 |
-
|
| 350 |
-
///////////////////////////////////////////////////////////////////////////////////////////////////
|
| 351 |
-
|
| 352 |
-
"""
|
| 353 |
-
|
| 354 |
-
def operation_is_3x(self, operation):
|
| 355 |
-
"""Whether operation is a CUTLASS 3 convolution (as opposed to CUTLASS 2)"""
|
| 356 |
-
return hasattr(operation, 'is_3x') and operation.is_3x
|
| 357 |
-
|
| 358 |
-
def __enter__(self):
|
| 359 |
-
"""
|
| 360 |
-
Open the configuration_file, and write the "header" C++ code to it.
|
| 361 |
-
|
| 362 |
-
The "header" consists of a comment (that this is generated code,
|
| 363 |
-
so it should not be edited), and includes that are common
|
| 364 |
-
to both the CUTLASS 2 and the CUTLASS 3 cases.
|
| 365 |
-
"""
|
| 366 |
-
_LOGGER.debug('*** EmitConv3dConfigurationLibrary::__enter__')
|
| 367 |
-
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 368 |
-
str(self.configuration_path))
|
| 369 |
-
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 370 |
-
self.configuration_file = open(self.configuration_path, "w")
|
| 371 |
-
|
| 372 |
-
self.configuration_file.write(SubstituteTemplate(self.header_template, {
|
| 373 |
-
'configuration_name': self.configuration_name
|
| 374 |
-
}))
|
| 375 |
-
self.operations = []
|
| 376 |
-
return self
|
| 377 |
-
|
| 378 |
-
def emit(self, operation):
|
| 379 |
-
"""
|
| 380 |
-
Write three pieces of C++ code to the configuration_file
|
| 381 |
-
(that was opened by the __enter__ method above):
|
| 382 |
-
|
| 383 |
-
1. the header includes that are specific to the operation
|
| 384 |
-
(CUTLASS 2 vs. CUTLASS 3);
|
| 385 |
-
|
| 386 |
-
2. the "operation instance" (a "using" declaration ending in "_base"); and
|
| 387 |
-
|
| 388 |
-
3. the "operation name" (declaration and definition of a derived class
|
| 389 |
-
of the above operation instance).
|
| 390 |
-
|
| 391 |
-
The "using" declaration turns a C++ class name, possibly namespace-qualified,
|
| 392 |
-
possibly also with angle brackets, into a C-style, easily demangled identifier.
|
| 393 |
-
"""
|
| 394 |
-
_LOGGER.debug('*** EmitConv3dConfigurationLibrary::emit')
|
| 395 |
-
_LOGGER.debug('*** operation.procedural_name(): ' + operation.procedural_name())
|
| 396 |
-
self.operations.append(operation)
|
| 397 |
-
|
| 398 |
-
self.configuration_file.write(self.includes_emitter.emit(operation))
|
| 399 |
-
|
| 400 |
-
stub_begin = ''
|
| 401 |
-
stub_end = ''
|
| 402 |
-
# It can be useful to stub (comment) out instantiations for testing.
|
| 403 |
-
# In this case, one need only set is_stub to True.
|
| 404 |
-
is_stub = False
|
| 405 |
-
if is_stub:
|
| 406 |
-
stub_begin = "// STUB for now\n#if 0"
|
| 407 |
-
stub_end = '#endif // 0'
|
| 408 |
-
|
| 409 |
-
self.configuration_file.write(Template(self.instance_template).substitute({
|
| 410 |
-
'configuration_name': self.configuration_name,
|
| 411 |
-
'operation_name': operation.procedural_name(),
|
| 412 |
-
'operation_instance': self.instance_emitter.emit(operation),
|
| 413 |
-
'stub_begin': stub_begin,
|
| 414 |
-
'stub_end': stub_end
|
| 415 |
-
}))
|
| 416 |
-
|
| 417 |
-
def __exit__(self, exception_type, exception_value, traceback):
|
| 418 |
-
"""
|
| 419 |
-
Write the rest of the C++ code to the configuration_file, and close the file.
|
| 420 |
-
|
| 421 |
-
The "rest of the C++ code" has the following components.
|
| 422 |
-
|
| 423 |
-
1. Configuration header: Open the namespace(s), and open the definition
|
| 424 |
-
of the "initialize_${configuration_name}" registration function
|
| 425 |
-
that registers the operation with the Manifest.
|
| 426 |
-
("Registration" helps turn C++ compile-time polymorphism
|
| 427 |
-
(via template parameters) into a run-time choice of parameters.)
|
| 428 |
-
|
| 429 |
-
2. Configuration instance: In the body of the registration function,
|
| 430 |
-
make a "using" declaration Operation_${operation_name} for the
|
| 431 |
-
operation type (which uses operation_name as its template argument).
|
| 432 |
-
Then, tell the manifest about the operation via a "manifest.append" call.
|
| 433 |
-
The argument of the call is a new instance of
|
| 434 |
-
"SomethingOperation<Operation_${operation_name}>"
|
| 435 |
-
(replace Something with a specific name).
|
| 436 |
-
|
| 437 |
-
3. Configuration epilogue: Close the definition of the registration function.
|
| 438 |
-
|
| 439 |
-
4. Epilogue template: Close the namespace(s).
|
| 440 |
-
"""
|
| 441 |
-
|
| 442 |
-
_LOGGER.debug('*** EmitConv3dConfigurationLibrary::__exit__')
|
| 443 |
-
_LOGGER.debug('*** configuration_path (file to write): ' +
|
| 444 |
-
str(self.configuration_path))
|
| 445 |
-
_LOGGER.debug('*** configuration_name: ' + self.configuration_name)
|
| 446 |
-
|
| 447 |
-
self.configuration_file.write(SubstituteTemplate(self.configuration_header, {
|
| 448 |
-
'configuration_name': self.configuration_name
|
| 449 |
-
}))
|
| 450 |
-
|
| 451 |
-
for operation in self.operations:
|
| 452 |
-
stub_begin = ''
|
| 453 |
-
stub_end = ''
|
| 454 |
-
# It can be useful to stub (comment) out instantiations for testing.
|
| 455 |
-
# In this case, one need only set is_stub to True.
|
| 456 |
-
is_stub = False
|
| 457 |
-
if is_stub:
|
| 458 |
-
stub_begin = "// STUB for now\n#if 0"
|
| 459 |
-
stub_end = "#endif // 0"
|
| 460 |
-
|
| 461 |
-
kernel_name = 'ImplicitGemmConvolution'
|
| 462 |
-
operation_wrapper = 'Conv3dOperation'
|
| 463 |
-
if self.operation_is_3x(operation):
|
| 464 |
-
kernel_name = 'ConvUniversalAdapter'
|
| 465 |
-
operation_wrapper = 'ConvOperation3x'
|
| 466 |
-
|
| 467 |
-
self.configuration_file.write(SubstituteTemplate(self.configuration_instance, {
|
| 468 |
-
'configuration_name': self.configuration_name,
|
| 469 |
-
'operation_name': operation.procedural_name(),
|
| 470 |
-
'kernel_name': kernel_name,
|
| 471 |
-
'operation_wrapper': operation_wrapper,
|
| 472 |
-
'stub_begin': stub_begin,
|
| 473 |
-
'stub_end': stub_end
|
| 474 |
-
}))
|
| 475 |
-
|
| 476 |
-
self.configuration_file.write(self.configuration_epilogue)
|
| 477 |
-
self.configuration_file.write(self.epilogue_template)
|
| 478 |
-
self.configuration_file.close()
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
###################################################################################################
|
| 482 |
-
###################################################################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch211-cxx11-cu130-aarch64-linux/include/third-party/cutlass/python/cutlass_library/conv3x_emitter.py
DELETED
|
@@ -1,250 +0,0 @@
|
|
| 1 |
-
#################################################################################################
|
| 2 |
-
#
|
| 3 |
-
# Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
-
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
#
|
| 6 |
-
# Redistribution and use in source and binary forms, with or without
|
| 7 |
-
# modification, are permitted provided that the following conditions are met:
|
| 8 |
-
#
|
| 9 |
-
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
-
# list of conditions and the following disclaimer.
|
| 11 |
-
#
|
| 12 |
-
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
-
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
-
# and/or other materials provided with the distribution.
|
| 15 |
-
#
|
| 16 |
-
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
-
# contributors may be used to endorse or promote products derived from
|
| 18 |
-
# this software without specific prior written permission.
|
| 19 |
-
#
|
| 20 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
-
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
-
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
-
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
-
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
-
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
-
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
-
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
-
#
|
| 31 |
-
#################################################################################################
|
| 32 |
-
|
| 33 |
-
"""
|
| 34 |
-
Utilities for emitting CUTLASS >= 3 convolution kernels
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
import enum
|
| 38 |
-
import os.path
|
| 39 |
-
import shutil
|
| 40 |
-
import logging
|
| 41 |
-
from string import Template
|
| 42 |
-
|
| 43 |
-
try:
|
| 44 |
-
import builtins
|
| 45 |
-
if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
|
| 46 |
-
raise ImportError("Disabling attempt to import cutlass_library")
|
| 47 |
-
from cutlass_library.library import *
|
| 48 |
-
except ImportError:
|
| 49 |
-
from library import *
|
| 50 |
-
|
| 51 |
-
_LOGGER = logging.getLogger(__name__)
|
| 52 |
-
|
| 53 |
-
###################################################################################################
|
| 54 |
-
#
|
| 55 |
-
# Emits single instances of a CUTLASS device-wide operator
|
| 56 |
-
#
|
| 57 |
-
###################################################################################################
|
| 58 |
-
|
| 59 |
-
class EmitConv3xInstance:
|
| 60 |
-
def __init__(self):
|
| 61 |
-
_LOGGER.debug("*** EmitConv3xInstance::__init__")
|
| 62 |
-
|
| 63 |
-
# Define epilogue type first, so that the mainloop type
|
| 64 |
-
# can use it with StageCountAutoCarveout.
|
| 65 |
-
self.template = """
|
| 66 |
-
|
| 67 |
-
// CUTLASS >= 3 convolution ${conv_kind_name} kernel instance "${operation_name}"
|
| 68 |
-
using ${operation_name}_epilogue =
|
| 69 |
-
typename cutlass::epilogue::collective::CollectiveBuilder<
|
| 70 |
-
${arch},
|
| 71 |
-
${opcode_class_epi},
|
| 72 |
-
${mma_tile_shape}, // mma tile shape
|
| 73 |
-
${cluster_shape}, // cluster shape
|
| 74 |
-
${epi_tile_mn},
|
| 75 |
-
${element_accumulator},
|
| 76 |
-
${element_compute},
|
| 77 |
-
${element_c}, ${layout_c}, 128 / cute::sizeof_bits_v<${element_c}>,
|
| 78 |
-
${element_d}, ${layout_d}, 128 / cute::sizeof_bits_v<${element_d}>,
|
| 79 |
-
${epilogue_schedule}
|
| 80 |
-
// , class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination<ElementD,ElementCompute>
|
| 81 |
-
>::CollectiveOp;
|
| 82 |
-
|
| 83 |
-
using ${operation_name}_mainloop =
|
| 84 |
-
typename cutlass::conv::collective::CollectiveBuilder<
|
| 85 |
-
${arch},
|
| 86 |
-
${opcode_class_main},
|
| 87 |
-
${conv_kind}, // kFprop, kDgrad, or kWgrad
|
| 88 |
-
${element_a}, ${layout_a}, 128 / cute::sizeof_bits_v<${element_a}>,
|
| 89 |
-
${element_b}, ${layout_b}, 128 / cute::sizeof_bits_v<${element_b}>,
|
| 90 |
-
${element_accumulator},
|
| 91 |
-
${mma_tile_shape}, // mma tile shape
|
| 92 |
-
${cluster_shape}, // cluster shape
|
| 93 |
-
${stages},
|
| 94 |
-
${kernel_schedule}
|
| 95 |
-
>::CollectiveOp;
|
| 96 |
-
|
| 97 |
-
using ${operation_name}_problem_shape = cutlass::conv::ConvProblemShape<${conv_kind}, ${operation_name}_mainloop::NumSpatialDimensions>;
|
| 98 |
-
|
| 99 |
-
// Unit tests call this "ConvKernel".
|
| 100 |
-
// Conv operator ${operation_name}
|
| 101 |
-
using ${operation_name}_base = cutlass::conv::kernel::ConvUniversal<
|
| 102 |
-
${operation_name}_problem_shape,
|
| 103 |
-
${operation_name}_mainloop,
|
| 104 |
-
${operation_name}_epilogue,
|
| 105 |
-
${tile_scheduler}
|
| 106 |
-
>;
|
| 107 |
-
"""
|
| 108 |
-
|
| 109 |
-
def arch_number_to_type(self, arch: int) -> str:
|
| 110 |
-
return f"cutlass::arch::Sm{arch}"
|
| 111 |
-
|
| 112 |
-
def mma_tile_shape(self, operation, cta_m, cta_n, cta_k) -> str:
|
| 113 |
-
mma_m = cta_m
|
| 114 |
-
mma_n = cta_n
|
| 115 |
-
mma_k = cta_k
|
| 116 |
-
|
| 117 |
-
if operation.arch >= 100:
|
| 118 |
-
# MmaTileShape (mma_m, mma_n, mma_k) is passed to kernel mainloop where
|
| 119 |
-
# mma_m = cta_m for 1sm version and mma_m = cta_m * 2 for 2sm version.
|
| 120 |
-
# If schedule is auto and cluster size is static and cta_m % 64 == 0 and cluster_m % 2 == 0, 2sm kernel version is allocated,
|
| 121 |
-
# otherwise 1sm kernel is allocated.
|
| 122 |
-
cta_m_per_mma_instruction = 1
|
| 123 |
-
if "2sm" in operation.procedural_name() :
|
| 124 |
-
cta_m_per_mma_instruction = 2
|
| 125 |
-
elif "1sm" in operation.procedural_name() :
|
| 126 |
-
cta_m_per_mma_instruction = 1
|
| 127 |
-
elif operation.tile_description.cluster_shape[0] > 0 and operation.tile_description.cluster_shape[0] % 2 == 0 and cta_m % 64 == 0 :
|
| 128 |
-
cta_m_per_mma_instruction = 2
|
| 129 |
-
mma_m = cta_m * cta_m_per_mma_instruction
|
| 130 |
-
|
| 131 |
-
# For all three kinds of convolutions, the tile shape's K mode
|
| 132 |
-
# differs from GEMM in that needs to be wrapped in a Shape.
|
| 133 |
-
# For Wgrad convolutions specifically,
|
| 134 |
-
# the N tile shape also needs to be wrapped in a Shape.
|
| 135 |
-
m_template = 'cute::_${mma_m}'
|
| 136 |
-
if operation.conv_kind == ConvKind.Wgrad:
|
| 137 |
-
n_template = 'cute::Shape<cute::_${mma_n}>'
|
| 138 |
-
else:
|
| 139 |
-
n_template = 'cute::_${mma_n}'
|
| 140 |
-
k_template = 'cute::Shape<cute::_${mma_k}>'
|
| 141 |
-
|
| 142 |
-
mma_tile_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
| 143 |
-
values = {
|
| 144 |
-
'mma_m': mma_m,
|
| 145 |
-
'mma_n': mma_n,
|
| 146 |
-
'mma_k': mma_k
|
| 147 |
-
}
|
| 148 |
-
return Template(mma_tile_shape_template).substitute(values)
|
| 149 |
-
|
| 150 |
-
def cluster_shape(self, operation) -> str:
|
| 151 |
-
m_template = 'cute::_${cluster_shape_m}' if operation.tile_description.cluster_shape[0] > 0 else 'int(0)'
|
| 152 |
-
n_template = 'cute::_${cluster_shape_n}' if operation.tile_description.cluster_shape[1] > 0 else 'int(0)'
|
| 153 |
-
k_template = 'cute::_${cluster_shape_k}' if operation.tile_description.cluster_shape[2] > 0 else 'int(0)'
|
| 154 |
-
cluster_shape_template = f'cute::Shape<{m_template}, {n_template}, {k_template}>'
|
| 155 |
-
values = {
|
| 156 |
-
'cluster_shape_m': operation.tile_description.cluster_shape[0],
|
| 157 |
-
'cluster_shape_n': operation.tile_description.cluster_shape[1],
|
| 158 |
-
'cluster_shape_k': operation.tile_description.cluster_shape[2],
|
| 159 |
-
}
|
| 160 |
-
return Template(cluster_shape_template).substitute(values)
|
| 161 |
-
|
| 162 |
-
def stage_count(self, operation) -> str:
|
| 163 |
-
# stages == 0 tells builder to pick the number of stages automatically
|
| 164 |
-
namespace_prefix = 'cutlass::conv::collective::'
|
| 165 |
-
if operation.tile_description.stages > 0:
|
| 166 |
-
return f"{namespace_prefix}StageCount<{str(operation.tile_description.stages)}>"
|
| 167 |
-
else:
|
| 168 |
-
return f"{namespace_prefix}StageCountAutoCarveout<sizeof(typename {operation.procedural_name()}_epilogue::SharedStorage)>"
|
| 169 |
-
|
| 170 |
-
def emit(self, operation) -> str:
|
| 171 |
-
_LOGGER.debug("*** EmitConv3xInstance::emit")
|
| 172 |
-
_LOGGER.debug("*** operation: procedural_name()=" + operation.procedural_name())
|
| 173 |
-
|
| 174 |
-
# Identify the operation as CUTLASS 3 by its is_3x field
|
| 175 |
-
if (not hasattr(operation, 'is_3x')) or (not operation.is_3x):
|
| 176 |
-
raise RuntimeError("operation must be a CUTLASS 3 operation")
|
| 177 |
-
|
| 178 |
-
epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto"
|
| 179 |
-
opcode_class_main = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
|
| 180 |
-
opcode_class_epi = opcode_class_main
|
| 181 |
-
|
| 182 |
-
tile_shape = operation.tile_description.tile_shape
|
| 183 |
-
cluster_m = operation.tile_description.cluster_shape[0]
|
| 184 |
-
cluster_n = operation.tile_description.cluster_shape[1]
|
| 185 |
-
|
| 186 |
-
cta_m, cta_n, cta_k = tile_shape
|
| 187 |
-
# account for static/dynamic cluster shapes
|
| 188 |
-
if operation.arch >= 100:
|
| 189 |
-
cta_m = cta_m // cluster_m if cluster_m > 0 else cta_m
|
| 190 |
-
cta_n = cta_n // cluster_n if cluster_n > 0 else cta_n
|
| 191 |
-
|
| 192 |
-
warp_count = operation.tile_description.warp_count
|
| 193 |
-
epilogue_schedule = EpilogueScheduleTag[operation.epilogue_schedule]
|
| 194 |
-
|
| 195 |
-
# KernelScheduleTag and TileSchedulerTag both hard-code the
|
| 196 |
-
# namespace qualification of KernelScheduleAuto as
|
| 197 |
-
# "cutlass::gemm::collective::" (unless the tag is 'void').
|
| 198 |
-
#
|
| 199 |
-
# For TileSchedulerTag, this namespace is fine, since CUTLASS 3
|
| 200 |
-
# convolutions use the same tile schedulers (from the same
|
| 201 |
-
# cutlass::gemm::collective namespace) as GEMMs.
|
| 202 |
-
kernel_schedule = KernelScheduleTag[operation.kernel_schedule].replace('gemm::', 'conv::')
|
| 203 |
-
tile_scheduler = TileSchedulerTag[operation.tile_scheduler]
|
| 204 |
-
opcode_class = OpcodeClassTag[operation.tile_description.math_instruction.opcode_class]
|
| 205 |
-
|
| 206 |
-
values = {
|
| 207 |
-
'operation_name': operation.procedural_name(),
|
| 208 |
-
'conv_kind': ConvKindTag[operation.conv_kind],
|
| 209 |
-
'conv_kind_name': ConvKindNames[operation.conv_kind].capitalize(),
|
| 210 |
-
'element_a': DataTypeTag[operation.A.element],
|
| 211 |
-
'layout_a': LayoutTag[operation.A.layout],
|
| 212 |
-
'align_a': int(operation.A.alignment),
|
| 213 |
-
'element_b': DataTypeTag[operation.B.element],
|
| 214 |
-
'layout_b': LayoutTag[operation.B.layout],
|
| 215 |
-
'align_b': int(operation.B.alignment),
|
| 216 |
-
'element_c': DataTypeTag[operation.C.element],
|
| 217 |
-
'layout_c': LayoutTag[operation.C.layout],
|
| 218 |
-
'align_c': int(operation.C.alignment),
|
| 219 |
-
'element_d': DataTypeTag[operation.D.element],
|
| 220 |
-
'layout_d': LayoutTag[operation.D.layout],
|
| 221 |
-
'align_d': int(operation.D.alignment),
|
| 222 |
-
'element_accumulator': DataTypeTag[operation.accumulator_type()],
|
| 223 |
-
'opcode_class': opcode_class,
|
| 224 |
-
'arch': self.arch_number_to_type(operation.arch),
|
| 225 |
-
'mma_tile_shape': self.mma_tile_shape(operation, cta_m, cta_n, cta_k),
|
| 226 |
-
'cluster_shape': self.cluster_shape(operation),
|
| 227 |
-
'opcode_class_epi': opcode_class_epi,
|
| 228 |
-
'opcode_class_main': opcode_class_main,
|
| 229 |
-
'epi_tile_mn': epi_tile_mn,
|
| 230 |
-
'stages': self.stage_count(operation),
|
| 231 |
-
'kernel_schedule': kernel_schedule,
|
| 232 |
-
'epilogue_schedule': epilogue_schedule,
|
| 233 |
-
'tile_scheduler': tile_scheduler,
|
| 234 |
-
'element_compute': DataTypeTag[operation.element_compute]
|
| 235 |
-
}
|
| 236 |
-
return Template(self.template).substitute(values)
|
| 237 |
-
|
| 238 |
-
class EmitConv3xIncludes:
|
| 239 |
-
def __init__(self):
|
| 240 |
-
_LOGGER.debug("*** EmitConv3xIncludes::__init__")
|
| 241 |
-
self.includes = ['conv_operation_3x.hpp',
|
| 242 |
-
'cutlass/conv/device/conv_universal_adapter.hpp',
|
| 243 |
-
'cutlass/conv/kernel/conv_universal.hpp',
|
| 244 |
-
'cutlass/conv/collective/collective_builder.hpp',
|
| 245 |
-
'cutlass/epilogue/collective/collective_builder.hpp']
|
| 246 |
-
|
| 247 |
-
def emit(self, operation) -> str:
|
| 248 |
-
_LOGGER.debug("*** EmitConv3xIncludes::emit")
|
| 249 |
-
return '\n'.join(f"#include \"{incl}\"" for incl in self.includes) + \
|
| 250 |
-
"\n\n///////////////////////////////////////////////////////////////////////////////////////////////////"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|