Build uploaded using `kernels` (batch 5/10).
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +1 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +173 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +142 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +319 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +180 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py +44 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py +235 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py +134 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py +128 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py +146 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py +104 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py +103 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py +71 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py +112 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py +75 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py +103 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py +98 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py +423 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py +44 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py +260 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py +57 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py +284 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py +254 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py +354 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py +69 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py +75 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py +95 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py +92 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py +213 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py +80 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py +87 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py +96 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py +59 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h +102 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h +907 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h +927 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h +818 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h +666 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h +622 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h +734 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h +643 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h +293 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h +716 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h +732 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h +473 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp +1385 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp +768 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp +158 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp +775 -0
- build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp +217 -0
.gitattributes
CHANGED
|
@@ -14,3 +14,4 @@ build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf
|
|
| 14 |
build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 15 |
build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 16 |
build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 14 |
build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 15 |
build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 16 |
build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit test for store nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTLayout(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
def test_permute_1(self):
|
| 53 |
+
"""
|
| 54 |
+
Returning a tensor with shape [m, n]
|
| 55 |
+
"""
|
| 56 |
+
def evt_permute(accum, alpha, C):
|
| 57 |
+
F = alpha * accum
|
| 58 |
+
F_permute = permute(F, indices=(0, 2, 1))
|
| 59 |
+
D_permute = F_permute + permute(C, indices=(0, 2, 1))
|
| 60 |
+
D = permute(D_permute, indices=(0, 2, 1))
|
| 61 |
+
return D, F
|
| 62 |
+
|
| 63 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 64 |
+
example_inputs = {
|
| 65 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 66 |
+
"alpha": 0.5,
|
| 67 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 68 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 69 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
launcher = EVTTestBed(self.element, evt_permute, example_inputs)
|
| 73 |
+
input_keys = ["C", "alpha"]
|
| 74 |
+
result_keys = ["D", "F"]
|
| 75 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 76 |
+
|
| 77 |
+
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
|
| 78 |
+
def test_permute_2(self):
|
| 79 |
+
"""
|
| 80 |
+
Returning a tensor with shape [m, n]
|
| 81 |
+
"""
|
| 82 |
+
def evt_permute(accum, alpha, C):
|
| 83 |
+
F = alpha * accum
|
| 84 |
+
F_permute = permute(F, indices=(0, 2, 1))
|
| 85 |
+
D = F_permute + C
|
| 86 |
+
return D, F
|
| 87 |
+
|
| 88 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 89 |
+
example_inputs = {
|
| 90 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 91 |
+
"alpha": 0.5,
|
| 92 |
+
"C": self.fake_tensor(self.element, (l, n, m)),
|
| 93 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 94 |
+
"D": self.fake_tensor(self.element, (l, n, m)),
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
launcher = EVTTestBed(self.element, evt_permute, example_inputs)
|
| 98 |
+
input_keys = ["C", "alpha"]
|
| 99 |
+
result_keys = ["D", "F"]
|
| 100 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 101 |
+
|
| 102 |
+
@unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
|
| 103 |
+
def test_permute_3(self):
|
| 104 |
+
"""
|
| 105 |
+
Returning a tensor with shape [m, n]
|
| 106 |
+
"""
|
| 107 |
+
def evt_permute(accum, alpha, C):
|
| 108 |
+
F = alpha * accum
|
| 109 |
+
F_permute = permute(F, indices=(1, 0, 2))
|
| 110 |
+
D = F_permute + C
|
| 111 |
+
return D, F
|
| 112 |
+
|
| 113 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 114 |
+
example_inputs = {
|
| 115 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 116 |
+
"alpha": 0.5,
|
| 117 |
+
"C": self.fake_tensor(self.element, (m, l, n)),
|
| 118 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 119 |
+
"D": self.fake_tensor(self.element, (m, l, n)),
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
launcher = EVTTestBed(self.element, evt_permute, example_inputs)
|
| 123 |
+
input_keys = ["C", "alpha"]
|
| 124 |
+
result_keys = ["D", "F"]
|
| 125 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 126 |
+
|
| 127 |
+
def test_reshape(self):
|
| 128 |
+
"""
|
| 129 |
+
Test reshape
|
| 130 |
+
"""
|
| 131 |
+
def evt_reshape(accum, alpha, TensorE):
|
| 132 |
+
F = alpha * accum
|
| 133 |
+
E_reshape = reshape(TensorE, new_shape=(512, 1))
|
| 134 |
+
D = F + E_reshape
|
| 135 |
+
return D
|
| 136 |
+
|
| 137 |
+
example_inputs = {
|
| 138 |
+
"accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
|
| 139 |
+
"alpha": 0.5,
|
| 140 |
+
"TensorE": self.fake_tensor(self.element, (16, 32)),
|
| 141 |
+
"D": self.fake_tensor(self.element, (self.l, self.m, self.n)),
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
|
| 145 |
+
input_keys = ["alpha", "TensorE"]
|
| 146 |
+
result_keys = ["D"]
|
| 147 |
+
launcher.verify(self.problem_size, input_keys, result_keys, self.l)
|
| 148 |
+
|
| 149 |
+
def test_reshape2(self):
|
| 150 |
+
"""
|
| 151 |
+
Test reshape
|
| 152 |
+
"""
|
| 153 |
+
def evt_reshape(accum, alpha, TensorE):
|
| 154 |
+
F = alpha * accum
|
| 155 |
+
F_reshape = reshape(F, new_shape=(2, 3, 512, 256))
|
| 156 |
+
D = F_reshape + TensorE
|
| 157 |
+
return D
|
| 158 |
+
|
| 159 |
+
example_inputs = {
|
| 160 |
+
"accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
|
| 161 |
+
"alpha": 0.5,
|
| 162 |
+
"TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)),
|
| 163 |
+
"D": self.fake_tensor(self.element, (2, 3, self.m, self.n)),
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
|
| 167 |
+
input_keys = ["alpha", "TensorE"]
|
| 168 |
+
result_keys = ["D"]
|
| 169 |
+
launcher.verify(self.problem_size, input_keys, result_keys, self.l)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
if __name__ == '__main__':
|
| 173 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit test for load nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTLoad(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
def test_tensor_load(self):
|
| 53 |
+
"""
|
| 54 |
+
Load extra tensor with shape [m, n]
|
| 55 |
+
"""
|
| 56 |
+
def evt_tensor_load(accum, C, aux, aux_batch):
|
| 57 |
+
D = accum + C + aux + aux_batch
|
| 58 |
+
return D
|
| 59 |
+
|
| 60 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 61 |
+
example_inputs = {
|
| 62 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 63 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 64 |
+
"aux": self.fake_tensor(self.element, (m, n)),
|
| 65 |
+
"aux_batch": self.fake_tensor(np.float32, (l, m, n)),
|
| 66 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs)
|
| 70 |
+
input_keys = ["C", "aux", "aux_batch"]
|
| 71 |
+
result_keys = ["D"]
|
| 72 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 73 |
+
|
| 74 |
+
def test_row_broadcast(self):
|
| 75 |
+
"""
|
| 76 |
+
Load extra tensor with shape [1, n]
|
| 77 |
+
"""
|
| 78 |
+
def evt_row_broadcast(accum, C, bias, bias_batch):
|
| 79 |
+
D = accum + C + bias + bias_batch
|
| 80 |
+
return D
|
| 81 |
+
|
| 82 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 83 |
+
example_inputs = {
|
| 84 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 85 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 86 |
+
"bias": self.fake_tensor(self.element, (n,)),
|
| 87 |
+
"bias_batch": self.fake_tensor(np.float32, (l, 1, n)),
|
| 88 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs)
|
| 92 |
+
input_keys = ["C", "bias", "bias_batch"]
|
| 93 |
+
result_keys = ["D"]
|
| 94 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 95 |
+
|
| 96 |
+
def test_column_broadcast(self):
|
| 97 |
+
"""
|
| 98 |
+
Load extra tensor with shape [m, 1]
|
| 99 |
+
"""
|
| 100 |
+
def evt_column_broadcast(accum, C, bias, bias_batch):
|
| 101 |
+
D = accum + C + bias + bias_batch
|
| 102 |
+
return D
|
| 103 |
+
|
| 104 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 105 |
+
example_inputs = {
|
| 106 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 107 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 108 |
+
"bias": self.fake_tensor(self.element, (m, 1)),
|
| 109 |
+
"bias_batch": self.fake_tensor(np.float32, (l, m, 1)),
|
| 110 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs)
|
| 114 |
+
input_keys = ["C", "bias", "bias_batch"]
|
| 115 |
+
result_keys = ["D"]
|
| 116 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 117 |
+
|
| 118 |
+
def test_scalar_broadcast(self):
|
| 119 |
+
"""
|
| 120 |
+
Load extra tensor with shape [1, 1]
|
| 121 |
+
"""
|
| 122 |
+
def evt_scalar_broadcast(accum, C, alpha, alpha_batch):
|
| 123 |
+
D = accum + C + alpha + alpha_batch
|
| 124 |
+
return D
|
| 125 |
+
|
| 126 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 127 |
+
example_inputs = {
|
| 128 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 129 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 130 |
+
"alpha": 0.5,
|
| 131 |
+
"alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)),
|
| 132 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs)
|
| 136 |
+
input_keys = ["C", "alpha", "alpha_batch"]
|
| 137 |
+
result_keys = ["D"]
|
| 138 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
if __name__ == '__main__':
|
| 142 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unittest for mixed types of nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK
|
| 44 |
+
|
| 45 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 51 |
+
class TestEVTMixed(EVTTestCaseBase):
|
| 52 |
+
|
| 53 |
+
def test_same_variable_used_multiple_times(self):
|
| 54 |
+
"""
|
| 55 |
+
The same variable z0 is used multiple times
|
| 56 |
+
"""
|
| 57 |
+
def evt_aux_store(accum):
|
| 58 |
+
z0 = relu(accum)
|
| 59 |
+
D = z0 + z0
|
| 60 |
+
return z0, D
|
| 61 |
+
|
| 62 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 63 |
+
example_inputs = {
|
| 64 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 65 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 66 |
+
"z0": self.fake_tensor(self.element, (l, m, n)),
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
|
| 70 |
+
input_keys = ["accum"]
|
| 71 |
+
result_keys = ["z0", "D"]
|
| 72 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 73 |
+
|
| 74 |
+
def test_no_lca(self):
|
| 75 |
+
"""
|
| 76 |
+
The same variable z0 is used multiple times
|
| 77 |
+
"""
|
| 78 |
+
def evt_no_lca(accum, bias):
|
| 79 |
+
E = relu(accum)
|
| 80 |
+
F = E + bias
|
| 81 |
+
tmp_2 = E + 2
|
| 82 |
+
D = tmp_2 + E
|
| 83 |
+
return D
|
| 84 |
+
|
| 85 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 86 |
+
example_inputs = {
|
| 87 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 88 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 89 |
+
"bias": self.fake_tensor(self.element, (m,1), stride=(1,0)),
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
launcher = EVTTestBed(self.element, evt_no_lca, example_inputs)
|
| 93 |
+
input_keys = ["accum", "bias"]
|
| 94 |
+
result_keys = ["D"]
|
| 95 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 96 |
+
|
| 97 |
+
def test_mixed_dag(self):
|
| 98 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 99 |
+
F = alpha * accum + (beta * C + aux)
|
| 100 |
+
F_row_max = max(F, dim=[0, 1])
|
| 101 |
+
E = relu(F + 1) + cbias + rbias
|
| 102 |
+
E_col_max = max(E, dim=[0, 2])
|
| 103 |
+
D = E + F
|
| 104 |
+
return D, F, F_row_max, E_col_max
|
| 105 |
+
|
| 106 |
+
if device_cc() == 80:
|
| 107 |
+
alignments = [2, 4, 8]
|
| 108 |
+
else:
|
| 109 |
+
# Sm90 EVT currently only supports 128-bit alignment
|
| 110 |
+
alignments = [8,]
|
| 111 |
+
for align in alignments:
|
| 112 |
+
for m, n, k, l in self.get_problem_sizes(align):
|
| 113 |
+
example_inputs = {
|
| 114 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 115 |
+
"alpha": 1.0,
|
| 116 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 117 |
+
"beta": 1.0,
|
| 118 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 119 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 120 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 121 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 122 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 123 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 124 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs)
|
| 128 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 129 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 130 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 131 |
+
|
| 132 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 133 |
+
def test_mixed_dag_float(self):
|
| 134 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 135 |
+
F = alpha * accum + (beta * C + aux)
|
| 136 |
+
F_row_max = max(F, dim=[0, 1])
|
| 137 |
+
E = relu(F + 1) + cbias + rbias
|
| 138 |
+
E_col_max = max(E, dim=[0, 2])
|
| 139 |
+
D = E + F
|
| 140 |
+
return D, F, F_row_max, E_col_max
|
| 141 |
+
|
| 142 |
+
for align in [3, 2, 4]:
|
| 143 |
+
for m, n, k, l in self.get_problem_sizes(align):
|
| 144 |
+
example_inputs = {
|
| 145 |
+
"accum": self.fake_tensor(np.float32, (l, m, n)),
|
| 146 |
+
"alpha": 1.0,
|
| 147 |
+
"C": self.fake_tensor(np.float32, (l, m, n)),
|
| 148 |
+
"beta": 1.0,
|
| 149 |
+
"aux": self.fake_tensor(np.float32, (l, m, n)),
|
| 150 |
+
"cbias": self.fake_tensor(np.float32, (m, 1)),
|
| 151 |
+
"rbias": self.fake_tensor(np.float32, (n,)),
|
| 152 |
+
"D": self.fake_tensor(np.float32, (l, m, n)),
|
| 153 |
+
"F": self.fake_tensor(np.float32, (l, m, n)),
|
| 154 |
+
"F_row_max": self.fake_tensor(np.float32, (n,)),
|
| 155 |
+
"E_col_max": self.fake_tensor(np.float32, (m, 1))
|
| 156 |
+
}
|
| 157 |
+
launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs)
|
| 158 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 159 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 160 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 161 |
+
|
| 162 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 163 |
+
def test_mixed_dag_stage2(self):
|
| 164 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 165 |
+
F = alpha * accum + (beta * C + aux)
|
| 166 |
+
F_row_max = max(F, dim=[0, 1])
|
| 167 |
+
E = relu(F + 1) + cbias + rbias
|
| 168 |
+
E_col_max = max(E, dim=[0, 2])
|
| 169 |
+
D = E + F
|
| 170 |
+
return D, F, F_row_max, E_col_max
|
| 171 |
+
|
| 172 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 173 |
+
example_inputs = {
|
| 174 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 175 |
+
"alpha": 1.0,
|
| 176 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 177 |
+
"beta": 1.0,
|
| 178 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 179 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 180 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 181 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 182 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 183 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 184 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2)
|
| 188 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 189 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 190 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 191 |
+
|
| 192 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 193 |
+
def test_mixed_dag_partition_k(self):
|
| 194 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 195 |
+
F = alpha * accum + (beta * C + aux)
|
| 196 |
+
F_row_max = max(F, dim=[0, 1])
|
| 197 |
+
E = relu(F + 1) + cbias + rbias
|
| 198 |
+
E_col_max = max(E, dim=[0, 2])
|
| 199 |
+
D = E + F
|
| 200 |
+
return D, F, F_row_max, E_col_max
|
| 201 |
+
|
| 202 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 203 |
+
example_inputs = {
|
| 204 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 205 |
+
"alpha": 1.0,
|
| 206 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 207 |
+
"beta": 1.0,
|
| 208 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 209 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 210 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 211 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 212 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 213 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 214 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
tile_description = {
|
| 218 |
+
"threadblock_shape": [128, 128, 64],
|
| 219 |
+
"warp_count": [2, 2, 2]
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2)
|
| 223 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 224 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 225 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 226 |
+
|
| 227 |
+
@unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
|
| 228 |
+
def test_mixed_dag_stream_k(self):
|
| 229 |
+
def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
|
| 230 |
+
F = alpha * accum + (beta * C + aux)
|
| 231 |
+
F_row_max = max(F, dim=[0, 1])
|
| 232 |
+
E = relu(F + 1) + cbias + rbias
|
| 233 |
+
E_col_max = max(E, dim=[0, 2])
|
| 234 |
+
D = E + F
|
| 235 |
+
return D, F, F_row_max, E_col_max
|
| 236 |
+
|
| 237 |
+
# High per-sm occupancy tile_description
|
| 238 |
+
tile_description = {
|
| 239 |
+
"threadblock_shape": [128, 128, 32],
|
| 240 |
+
"warp_count": [2, 2, 1],
|
| 241 |
+
"stages": 3
|
| 242 |
+
}
|
| 243 |
+
tds = [None, tile_description]
|
| 244 |
+
for td in tds:
|
| 245 |
+
for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]):
|
| 246 |
+
if l == 1:
|
| 247 |
+
example_inputs = {
|
| 248 |
+
"accum": self.fake_tensor(self.element, (m, n)),
|
| 249 |
+
"alpha": 1.0,
|
| 250 |
+
"C": self.fake_tensor(self.element, (m, n)),
|
| 251 |
+
"beta": 1.0,
|
| 252 |
+
"aux": self.fake_tensor(self.element, (m, n)),
|
| 253 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 254 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 255 |
+
"D": self.fake_tensor(self.element, (m, n)),
|
| 256 |
+
"F": self.fake_tensor(self.element, (m, n)),
|
| 257 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 258 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 259 |
+
}
|
| 260 |
+
else:
|
| 261 |
+
example_inputs = {
|
| 262 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 263 |
+
"alpha": 1.0,
|
| 264 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 265 |
+
"beta": 1.0,
|
| 266 |
+
"aux": self.fake_tensor(self.element, (l, m, n)),
|
| 267 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 268 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 269 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 270 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 271 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 272 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
if td is not None:
|
| 276 |
+
launcher = EVTTestBed(
|
| 277 |
+
self.element, evt_mixed_dag, example_inputs,
|
| 278 |
+
tile_description=td,
|
| 279 |
+
swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
|
| 280 |
+
else:
|
| 281 |
+
launcher = EVTTestBed(
|
| 282 |
+
self.element, evt_mixed_dag, example_inputs,
|
| 283 |
+
swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
|
| 284 |
+
|
| 285 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 286 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 287 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 288 |
+
|
| 289 |
+
def test_mixed_dag_no_batch(self):
|
| 290 |
+
def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias):
|
| 291 |
+
F = alpha * accum + (beta * C + aux)
|
| 292 |
+
F_row_max = max(F, dim=[0, 1])
|
| 293 |
+
E = relu(F + 1) + cbias + rbias
|
| 294 |
+
E_col_max = max(E, dim=[0, 2])
|
| 295 |
+
D = E + F
|
| 296 |
+
return D, F, F_row_max, E_col_max
|
| 297 |
+
|
| 298 |
+
for m, n, k, _ in self.get_problem_sizes(8):
|
| 299 |
+
example_inputs = {
|
| 300 |
+
"accum": self.fake_tensor(self.element, (m, n)),
|
| 301 |
+
"alpha": 1.0,
|
| 302 |
+
"C": self.fake_tensor(self.element, (m, n)),
|
| 303 |
+
"beta": 1.0,
|
| 304 |
+
"aux": self.fake_tensor(self.element, (m, n)),
|
| 305 |
+
"cbias": self.fake_tensor(self.element, (m, 1)),
|
| 306 |
+
"rbias": self.fake_tensor(self.element, (n,)),
|
| 307 |
+
"D": self.fake_tensor(self.element, (m, n)),
|
| 308 |
+
"F": self.fake_tensor(self.element, (m, n)),
|
| 309 |
+
"F_row_max": self.fake_tensor(DataType.f32, (n,)),
|
| 310 |
+
"E_col_max": self.fake_tensor(DataType.f32, (m, 1))
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs)
|
| 314 |
+
input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
|
| 315 |
+
result_keys = ["D", "F", "F_row_max", "E_col_max"]
|
| 316 |
+
launcher.verify((m, n, k), input_keys, result_keys, 1)
|
| 317 |
+
|
| 318 |
+
if __name__ == '__main__':
|
| 319 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit test for store nodes in SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen.backend import *
|
| 42 |
+
from cutlass_cppgen.epilogue import *
|
| 43 |
+
|
| 44 |
+
from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
|
| 45 |
+
|
| 46 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
|
| 50 |
+
class TestEVTStore(EVTTestCaseBase):
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() != 90, "This test is only for CC 90")
|
| 53 |
+
def test_invalid_store(self):
|
| 54 |
+
"""
|
| 55 |
+
Test invalid store
|
| 56 |
+
"""
|
| 57 |
+
def evt_invalid_store(accum):
|
| 58 |
+
D = accum
|
| 59 |
+
F = D + 1 # D has users, which is not allowed on SM90 or higher
|
| 60 |
+
return D, F
|
| 61 |
+
|
| 62 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 63 |
+
example_inputs = {
|
| 64 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 65 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 66 |
+
"F": self.fake_tensor(self.element, (l, m, n))
|
| 67 |
+
}
|
| 68 |
+
with self.assertRaisesRegex(
|
| 69 |
+
RuntimeError,
|
| 70 |
+
r"On SM90 or higher, D is expected to be a output node with 0 users "
|
| 71 |
+
r"to enable smem reuse between C and D, but got 1"
|
| 72 |
+
):
|
| 73 |
+
launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs)
|
| 74 |
+
|
| 75 |
+
break # Only need to test once
|
| 76 |
+
|
| 77 |
+
def test_aux_store(self):
|
| 78 |
+
"""
|
| 79 |
+
Returning a tensor with shape [m, n]
|
| 80 |
+
"""
|
| 81 |
+
def evt_aux_store(accum, alpha, C):
|
| 82 |
+
F = alpha * accum
|
| 83 |
+
D = F + C
|
| 84 |
+
return D, F
|
| 85 |
+
|
| 86 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 87 |
+
example_inputs = {
|
| 88 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 89 |
+
"alpha": 0.5,
|
| 90 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 91 |
+
"F": self.fake_tensor(self.element, (l, m, n)),
|
| 92 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
|
| 96 |
+
input_keys = ["C", "alpha"]
|
| 97 |
+
result_keys = ["D", "F"]
|
| 98 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 99 |
+
|
| 100 |
+
def test_col_reduce(self):
|
| 101 |
+
"""
|
| 102 |
+
Reduction [m, n] -> [m, 1]
|
| 103 |
+
"""
|
| 104 |
+
def evt_row_reduce(accum, alpha, C):
|
| 105 |
+
acc_row_max = max(accum, dim=[2,])
|
| 106 |
+
F = alpha * accum
|
| 107 |
+
F_row_max = max(F, dim=[0, 2])
|
| 108 |
+
D = F + C
|
| 109 |
+
return D, F_row_max, acc_row_max
|
| 110 |
+
|
| 111 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 112 |
+
example_inputs = {
|
| 113 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 114 |
+
"alpha": 2.0,
|
| 115 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 116 |
+
"F_row_max": self.fake_tensor(np.float32, (m, 1)),
|
| 117 |
+
"acc_row_max": self.fake_tensor(np.float32, (l, m, 1)),
|
| 118 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs)
|
| 122 |
+
input_keys = ["C", "alpha"]
|
| 123 |
+
result_keys = ["D", "F_row_max", "acc_row_max"]
|
| 124 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 125 |
+
|
| 126 |
+
def test_row_reduce(self):
|
| 127 |
+
"""
|
| 128 |
+
Reduction [m, n] -> [n]
|
| 129 |
+
"""
|
| 130 |
+
def evt_col_reduce(accum, alpha, C):
|
| 131 |
+
acc_col_max = max(accum, dim=[1,])
|
| 132 |
+
F = alpha * accum
|
| 133 |
+
F_col_max = max(F, dim=[0, 1])
|
| 134 |
+
D = F + C
|
| 135 |
+
return D, F_col_max, acc_col_max
|
| 136 |
+
|
| 137 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 138 |
+
example_inputs = {
|
| 139 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 140 |
+
"alpha": 2.0,
|
| 141 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 142 |
+
"F_col_max": self.fake_tensor(np.float32, (n,)),
|
| 143 |
+
"acc_col_max": self.fake_tensor(np.float32, (l, 1, n)),
|
| 144 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs)
|
| 148 |
+
input_keys = ["C", "alpha"]
|
| 149 |
+
result_keys = ["D", "F_col_max", "acc_col_max"]
|
| 150 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 151 |
+
|
| 152 |
+
def test_scalar_reduce(self):
|
| 153 |
+
"""
|
| 154 |
+
Reduction [m, n] -> [1,]
|
| 155 |
+
"""
|
| 156 |
+
def evt_scalar_reduce(accum, alpha, C):
|
| 157 |
+
acc_max = max(accum, dim=[1, 2])
|
| 158 |
+
F = alpha * accum
|
| 159 |
+
F_max = max(F, dim=[0, 1, 2])
|
| 160 |
+
D = F + C
|
| 161 |
+
return D, F_max, acc_max
|
| 162 |
+
|
| 163 |
+
for m, n, k, l in self.get_problem_sizes(8):
|
| 164 |
+
example_inputs = {
|
| 165 |
+
"accum": self.fake_tensor(self.element, (l, m, n)),
|
| 166 |
+
"alpha": 2.0,
|
| 167 |
+
"C": self.fake_tensor(self.element, (l, m, n)),
|
| 168 |
+
"acc_max": self.fake_tensor(np.float32, (l, 1, 1)),
|
| 169 |
+
"F_max": self.fake_tensor(np.float32, (1,)),
|
| 170 |
+
"D": self.fake_tensor(self.element, (l, m, n)),
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs)
|
| 174 |
+
input_keys = ["C", "alpha"]
|
| 175 |
+
result_keys = ["D", "F_max", "acc_max"]
|
| 176 |
+
launcher.verify((m, n, k), input_keys, result_keys, l)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if __name__ == '__main__':
|
| 180 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import pathlib
|
| 34 |
+
import unittest
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
|
| 38 |
+
loader = unittest.TestLoader()
|
| 39 |
+
script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
|
| 40 |
+
tests = loader.discover(script_dir, 'evt_*.py')
|
| 41 |
+
testRunner = unittest.runner.TextTestRunner()
|
| 42 |
+
results = testRunner.run(tests)
|
| 43 |
+
if not results.wasSuccessful():
|
| 44 |
+
raise Exception('Test cases failed')
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Testbed classes of EVT
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import torch
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen import Tensor
|
| 42 |
+
import cutlass_cppgen.backend.evt
|
| 43 |
+
from cutlass_cppgen.shape import GemmCoord
|
| 44 |
+
from cutlass_cppgen.utils.datatypes import torch_type
|
| 45 |
+
from cutlass_cppgen.utils.profiler import CUDAEventProfiler
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class EVTReferenceModule:
|
| 49 |
+
def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor):
|
| 50 |
+
self.layout_A = layout_A
|
| 51 |
+
self.layout_B = layout_B
|
| 52 |
+
self.layout_C = layout_C
|
| 53 |
+
self.epilogue_visitor = epilogue_visitor
|
| 54 |
+
|
| 55 |
+
def run(self, A, B, C, problem_size, alpha, beta, batch=1):
|
| 56 |
+
if self.layout_A == cutlass_cppgen.LayoutType.RowMajor:
|
| 57 |
+
A_row = A.view((batch, problem_size.m, problem_size.k))
|
| 58 |
+
else:
|
| 59 |
+
A_col = A.view((batch, problem_size.k, problem_size.m))
|
| 60 |
+
A_row = torch.permute(A_col, (0, 2, 1))
|
| 61 |
+
|
| 62 |
+
if self.layout_B == cutlass_cppgen.LayoutType.RowMajor:
|
| 63 |
+
B_row = B.view((batch, problem_size.k, problem_size.n))
|
| 64 |
+
else:
|
| 65 |
+
B_col = B.view((batch, problem_size.n, problem_size.k))
|
| 66 |
+
B_row = torch.permute(B_col, (0, 2, 1))
|
| 67 |
+
|
| 68 |
+
if self.layout_C == cutlass_cppgen.LayoutType.RowMajor:
|
| 69 |
+
C_row = C.view((batch, problem_size.m, problem_size.n))
|
| 70 |
+
else:
|
| 71 |
+
C_col = C.view((batch, problem_size.n, problem_size.m))
|
| 72 |
+
C_row = torch.permute(C_col, (0, 2, 1))
|
| 73 |
+
|
| 74 |
+
out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta
|
| 75 |
+
|
| 76 |
+
if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor:
|
| 77 |
+
out = torch.permute(out_row, (0, 2, 1))
|
| 78 |
+
else:
|
| 79 |
+
out = out_row
|
| 80 |
+
|
| 81 |
+
return torch.flatten(out)
|
| 82 |
+
|
| 83 |
+
def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None):
|
| 84 |
+
# Running the mainloop
|
| 85 |
+
accum = self.run(
|
| 86 |
+
A, B, C, problem_size, 1.0, 0.0, batch=batch
|
| 87 |
+
).reshape(batch, problem_size.m, problem_size.n)
|
| 88 |
+
|
| 89 |
+
# Running the epilogue
|
| 90 |
+
epilogue_args["accum"] = accum
|
| 91 |
+
references = self.epilogue_visitor(**epilogue_args)
|
| 92 |
+
|
| 93 |
+
# Return the results
|
| 94 |
+
if not isinstance(references, tuple):
|
| 95 |
+
references = (references,)
|
| 96 |
+
return references
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class EVTTestBed:
|
| 100 |
+
"""
|
| 101 |
+
Epilogue Visitor Testbed
|
| 102 |
+
"""
|
| 103 |
+
def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None:
|
| 104 |
+
self.element = element
|
| 105 |
+
layout = cutlass_cppgen.LayoutType.RowMajor
|
| 106 |
+
self.example_inputs = example_inputs
|
| 107 |
+
|
| 108 |
+
# Create the Gemm plan
|
| 109 |
+
self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32)
|
| 110 |
+
|
| 111 |
+
if "tile_description" in kwargs:
|
| 112 |
+
self.plan.tile_description = kwargs["tile_description"]
|
| 113 |
+
|
| 114 |
+
if "swizzling_functor" in kwargs:
|
| 115 |
+
self.plan.swizzling_functor = kwargs["swizzling_functor"]
|
| 116 |
+
|
| 117 |
+
# Compile the epilogue visitor
|
| 118 |
+
epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs)
|
| 119 |
+
if "epilogue_stages" in kwargs:
|
| 120 |
+
epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"]
|
| 121 |
+
self.plan.epilogue_visitor = epilogue_visitor
|
| 122 |
+
|
| 123 |
+
# Reference model
|
| 124 |
+
self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor)
|
| 125 |
+
|
| 126 |
+
self.profile = profile
|
| 127 |
+
|
| 128 |
+
def get_torch_tensor(self, shape, dtype=None, fill=None):
|
| 129 |
+
if dtype is None:
|
| 130 |
+
dtype = self.element
|
| 131 |
+
|
| 132 |
+
dtype = torch_type(dtype)
|
| 133 |
+
if fill is None:
|
| 134 |
+
return torch.ceil(
|
| 135 |
+
torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5)
|
| 136 |
+
)
|
| 137 |
+
else:
|
| 138 |
+
return torch.full(shape, fill, dtype=dtype, device="cuda")
|
| 139 |
+
|
| 140 |
+
def verify(self, problem_size, input_keys, result_keys, batch_count=1):
|
| 141 |
+
"""
|
| 142 |
+
Verify the results
|
| 143 |
+
"""
|
| 144 |
+
problem_size = GemmCoord(*problem_size)
|
| 145 |
+
|
| 146 |
+
# Initiate the GEMM arguments
|
| 147 |
+
tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k))
|
| 148 |
+
tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n))
|
| 149 |
+
|
| 150 |
+
# Initialize the epilogue args
|
| 151 |
+
epilogue_args = {}
|
| 152 |
+
for key in self.example_inputs.keys():
|
| 153 |
+
if key in input_keys:
|
| 154 |
+
tensor = self.example_inputs[key]
|
| 155 |
+
if isinstance(tensor, Tensor):
|
| 156 |
+
epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element)
|
| 157 |
+
else:
|
| 158 |
+
epilogue_args[key] = tensor
|
| 159 |
+
elif key in result_keys:
|
| 160 |
+
tensor = self.example_inputs[key]
|
| 161 |
+
if isinstance(tensor, Tensor):
|
| 162 |
+
if "max" in key:
|
| 163 |
+
fill = -1000
|
| 164 |
+
else:
|
| 165 |
+
fill = 0
|
| 166 |
+
epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element, fill=fill)
|
| 167 |
+
else:
|
| 168 |
+
epilogue_args[key] = tensor
|
| 169 |
+
|
| 170 |
+
tensor_D = epilogue_args["D"]
|
| 171 |
+
if "C" in epilogue_args:
|
| 172 |
+
tensor_C = epilogue_args["C"]
|
| 173 |
+
else:
|
| 174 |
+
tensor_C = tensor_D
|
| 175 |
+
# Run the device kernel
|
| 176 |
+
self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args)
|
| 177 |
+
|
| 178 |
+
# Run the host reference
|
| 179 |
+
evt_args_inputs = {}
|
| 180 |
+
for key in input_keys:
|
| 181 |
+
evt_args_inputs[key] = epilogue_args[key]
|
| 182 |
+
|
| 183 |
+
reference_results = self.reference_fn(
|
| 184 |
+
tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs)
|
| 185 |
+
|
| 186 |
+
# Compare the results
|
| 187 |
+
for result, ref in zip(result_keys, reference_results):
|
| 188 |
+
assert torch.equal(
|
| 189 |
+
epilogue_args[result].flatten(),
|
| 190 |
+
ref.masked_fill(torch.isnan(ref), float('inf')).flatten())
|
| 191 |
+
|
| 192 |
+
# Run profile
|
| 193 |
+
if self.profile:
|
| 194 |
+
profiler = CUDAEventProfiler(
|
| 195 |
+
self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D,
|
| 196 |
+
visitor_args = epilogue_args
|
| 197 |
+
)
|
| 198 |
+
print(f"Cutlass Python Duration: {profiler()}")
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class EVTTestCaseBase(unittest.TestCase):
    """
    Common fixture shared by the EVT unit tests.

    Holds the default element type and GEMM extents (l, m, n, k), and
    provides helpers for building symbolic tensors and problem-size sweeps.
    """
    def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None:
        super().__init__(methodName)

        # EVT tests default to half-precision operands.
        self.element = cutlass_cppgen.DataType.f16
        self.l, self.m, self.n, self.k = lmnk

        # (M, N, K) extents used when a test needs a single problem size.
        self.problem_size = (self.m, self.n, self.k)

        # Deterministic tensor contents across runs.
        torch.random.manual_seed(42)

    def fake_tensor(self, element, shape, stride=None):
        # Build a symbolic (data-free) tensor description. A row-major
        # layout tag is assumed unless an explicit stride is supplied.
        if stride is not None:
            return Tensor(element=element, shape=shape, stride=stride)
        return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor)

    def get_problem_sizes(self, alignment, k=None, batch_count=[3,]):
        # Enumerate (m, n, k, l) combinations compatible with ``alignment``.
        k = k if k else self.k
        ms = [alignment, 512 - 3 * alignment]
        ns = [alignment, 512 - alignment]
        if alignment % 8 == 0:
            # Larger extents are only exercised for 8-element alignment.
            ms.append(768)
            ns.append(768)
        return [(m, n, k, l) for m in ms for n in ns for l in batch_count]
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
High-level tests for running batched GEMMs
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
from math import prod
|
| 40 |
+
import unittest
|
| 41 |
+
|
| 42 |
+
import cutlass_cppgen
|
| 43 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 44 |
+
import torch
|
| 45 |
+
|
| 46 |
+
from utils import LayoutCombination
|
| 47 |
+
|
| 48 |
+
# Keep CUTLASS emitter output quiet during test runs.
cutlass_cppgen.set_log_level(logging.WARNING)

# Seed the RNG so tensor contents are reproducible across runs.
torch.manual_seed(2023)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def pytorch_reference(A, B, C, alpha, beta):
    """
    Compute ``alpha * (A @ B) + beta * C`` with PyTorch as a host
    reference for batched GEMM.

    Operands without a batch dimension are replicated across the batch;
    any operand that does have a batch dimension is assumed to share its
    batch extents with the others.
    """
    # Determine the batch extents from the first operand that has one;
    # fall back to a unit batch when all operands are plain matrices.
    batch_dims = next((t.shape[:-2] for t in (A, B, C) if t.dim() > 2), (1,))
    flat_batch = prod(batch_dims)

    def to_batched(t):
        # Collapse all batch dimensions into one so torch.bmm sees
        # matching 3-D operands; replicate unbatched matrices.
        if t.dim() == 2:
            return t.unsqueeze(0).repeat(flat_batch, 1, 1)
        return t.reshape(-1, t.size(-2), t.size(-1))

    A = to_batched(A)
    B = to_batched(B)
    C = to_batched(C)

    out = alpha * torch.bmm(A, B) + beta * C
    # Restore the original (possibly multi-dimensional) batch extents.
    return out.reshape(*(batch_dims + C.shape[-2:]))
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def initialize(rows, cols, batch):
    """
    Create a half-precision CUDA tensor of shape ``batch + (rows, cols)``
    (or just ``(rows, cols)`` when the batch is trivial), filled with
    random integers in [-3, 3).
    """
    count = rows * cols * prod(batch)
    flat = torch.randint(-3, 3, size=(count,), device='cuda').half()
    if batch and prod(batch) > 1:
        # Keep the full multi-dimensional batch shape.
        return flat.reshape(*(batch + (rows, cols)))
    return flat.reshape(rows, cols)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class GemmF16Batched(unittest.TestCase):
    """
    Tests of F16 GEMM with every combination of batched / non-batched
    A, B, and C operands, for 1-D and 2-D batch extents.
    """

    def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
        # Fixed GEMM extents and epilogue scalars for every case.
        M = 512
        N = 256
        K = 128
        alpha = 1.
        beta = 2.

        # Operands that are not batched are created with a unit batch.
        A = initialize(M, K, batch_count if batch_A else (1,))
        B = initialize(K, N, batch_count if batch_B else (1,))
        C = initialize(M, N, batch_count if batch_C else (1,))
        D = initialize(M, N, batch_count)

        # Run the CUTLASS kernel, then check against the PyTorch reference.
        plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32)
        plan.run(A, B, C, D, alpha, beta)
        reference = pytorch_reference(A, B, C, alpha, beta)
        assert reference.equal(D)

    def test_batched_ABC(self):
        for batch in ((3,), (2, 3)):
            self.run_batched(batch, True, True, True)

    def test_batched_AB(self):
        for batch in ((3,), (2, 3)):
            self.run_batched(batch, True, True, False)

    def test_batched_AC(self):
        for batch in ((3,), (2, 3)):
            self.run_batched(batch, True, False, True)

    def test_batched_BC(self):
        for batch in ((3,), (2, 3)):
            self.run_batched(batch, False, True, True)

    def test_batched_A(self):
        for batch in ((3,), (2, 3)):
            self.run_batched(batch, True, False, False)

    def test_batched_B(self):
        for batch in ((3,), (2, 3)):
            self.run_batched(batch, False, True, False)
|
| 132 |
+
|
| 133 |
+
# Allow invoking this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with F16 operands on SM80
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Keep CUTLASS emitter output quiet during test runs.
cutlass_cppgen.set_log_level(logging.WARNING)
# Target compute capability and operand data type for every test below.
cc = 80
dtype = cutlass_cppgen.DataType.f16
|
| 50 |
+
|
| 51 |
+
# Skip the whole class when the GPU is older than SM80 or the installed
# torch has no matching f16 datatype.
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm80(unittest.TestCase):
    """
    Empty wrapper class; test methods are attached dynamically at module
    scope via the add_test_* helpers below.
    """
    pass
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Skip the whole class when the GPU is older than SM80 or the installed
# torch has no matching f16 datatype.
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm80StreamK(unittest.TestCase):
    """
    Empty wrapper class for the stream-K swizzle tests; test methods are
    attached dynamically at module scope via add_test_streamk below.
    """
    pass
|
| 67 |
+
|
| 68 |
+
# Every test in this file targets SM80 with f16 operands and a 1x1x1 cluster.
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])

# Tests using TensorOp
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)

# Exercise all eight combinations of row-/column-major layouts for A, B, and C.
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
# Vary threadblock and warp shapes.
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3)
# Vary alignments, accumulator type, and pipeline depth.
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                  element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)

# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)

add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
              element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
              element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
              element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
              element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
              element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)

# Stream K tests
add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                 element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
                 element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)

# Allow invoking this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with F16 operands on SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Keep CUTLASS emitter output quiet during test runs.
cutlass_cppgen.set_log_level(logging.WARNING)
# Target compute capability and operand data type for every test below.
cc = 90
dtype = cutlass_cppgen.DataType.f16
|
| 50 |
+
|
| 51 |
+
# Skip the whole class when the GPU is older than SM90 or the installed
# torch has no matching f16 datatype.
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
class GemmF16Sm90(unittest.TestCase):
    """
    Empty wrapper class; test methods are attached dynamically at module
    scope via the add_test_* helpers below.
    """
    pass
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Every test in this file targets SM90 with f16 operands; warp count is
# chosen by the kernel generator and only nvcc compilation is exercised.
add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype,
                               warp_count=None, compilation_modes=['nvcc'])

add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)

# Tests with 1x1x1 clusters
add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1])
add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3)
add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None)
add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5)
add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16,
                      element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None)

# Tests with different cluster shapes
add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None)
add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                       element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1])
add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1])
add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1])
add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1])
add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1])
add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1])
add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1])
add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
                       element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1])

# Tests for different schedule modes
add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4],
                            element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
                            opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None)
add_test_schedule(
    cluster_shape=[1, 1, 1],
    kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
    epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
)
add_test_schedule(
    cluster_shape=[1, 1, 1],
    kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
    epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative
)
add_test_schedule(
    cluster_shape=[2, 1, 1],
    kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
    epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
)
add_test_schedule(
    cluster_shape=[2, 1, 1],
    kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
    epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative
)

# Tests using SIMT
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2)
add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8])
add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8])
add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8])
add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8])
add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8])

# Tests with void-C kernels
add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
                       element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None,
                       cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void)

# Allow invoking this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with F32 operands on SM80
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 80
|
| 49 |
+
dtype = cutlass_cppgen.DataType.f32
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmF32Sm80(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 62 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 63 |
+
class GemmF32Sm80StreamK(unittest.TestCase):
|
| 64 |
+
"""
|
| 65 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 66 |
+
"""
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
| 71 |
+
|
| 72 |
+
# Tests using TensorOp
|
| 73 |
+
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
|
| 74 |
+
|
| 75 |
+
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
| 76 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
| 77 |
+
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
| 78 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
| 79 |
+
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
| 80 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
|
| 81 |
+
add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
| 82 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
|
| 83 |
+
# Tests using SIMT
|
| 84 |
+
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
|
| 85 |
+
|
| 86 |
+
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 87 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
| 88 |
+
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 89 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
| 90 |
+
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 91 |
+
element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
| 92 |
+
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 93 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
| 94 |
+
add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 95 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
| 96 |
+
|
| 97 |
+
# Stream K tests
|
| 98 |
+
add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
|
| 99 |
+
add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
|
| 100 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == '__main__':
|
| 104 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with F64 operands on SM80
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 80
|
| 49 |
+
dtype = cutlass_cppgen.DataType.f64
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmF64Sm80(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 62 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 63 |
+
class GemmF64Sm80StreamK(unittest.TestCase):
|
| 64 |
+
"""
|
| 65 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 66 |
+
"""
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
| 71 |
+
|
| 72 |
+
# Tests using TensorOp
|
| 73 |
+
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
|
| 74 |
+
|
| 75 |
+
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 76 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
| 77 |
+
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 78 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
|
| 79 |
+
add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 80 |
+
element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
|
| 81 |
+
|
| 82 |
+
# Tests using SIMT
|
| 83 |
+
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
|
| 84 |
+
|
| 85 |
+
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 86 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
| 87 |
+
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 88 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
| 89 |
+
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 90 |
+
element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
| 91 |
+
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 92 |
+
element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
| 93 |
+
add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 94 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
| 95 |
+
|
| 96 |
+
# Stream K tests
|
| 97 |
+
add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
|
| 98 |
+
add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
|
| 99 |
+
element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
|
| 103 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with F64 operands on SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 90
|
| 49 |
+
dtype = cutlass_cppgen.DataType.f64
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmF64Sm90(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1],
|
| 62 |
+
element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc'])
|
| 63 |
+
|
| 64 |
+
add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3)
|
| 65 |
+
add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3)
|
| 66 |
+
add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2)
|
| 67 |
+
add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
if __name__ == '__main__':
|
| 71 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with S8 operands on SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 90
|
| 49 |
+
dtype = cutlass_cppgen.DataType.e4m3
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmF8E4M3Sm90(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc'])
|
| 62 |
+
|
| 63 |
+
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
|
| 64 |
+
|
| 65 |
+
# Test with 1x1x1 clusters
|
| 66 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
|
| 67 |
+
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 68 |
+
|
| 69 |
+
# Tests with different cluster shapes
|
| 70 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
|
| 71 |
+
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 72 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
|
| 73 |
+
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 74 |
+
|
| 75 |
+
# Tests with warp-specialized ping-pong schedule
|
| 76 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
|
| 77 |
+
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
|
| 78 |
+
kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
|
| 79 |
+
epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized)
|
| 80 |
+
|
| 81 |
+
# Tests for SIMT
|
| 82 |
+
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
|
| 83 |
+
add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3,
|
| 84 |
+
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
#
|
| 88 |
+
# Add a test for E5M2
|
| 89 |
+
#
|
| 90 |
+
dtype = cutlass_cppgen.DataType.e5m2
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
| 94 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 95 |
+
class GemmF8E5M2Sm90(unittest.TestCase):
|
| 96 |
+
"""
|
| 97 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 98 |
+
"""
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc'])
|
| 103 |
+
|
| 104 |
+
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
|
| 105 |
+
|
| 106 |
+
# Tests with 1x1x1 clusters
|
| 107 |
+
add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype,
|
| 108 |
+
element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == '__main__':
|
| 112 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with mixed operands on SM80
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 80
|
| 49 |
+
dtype =cutlass_cppgen.DataType.f16
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmMixedSm80(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1],
|
| 62 |
+
opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64],
|
| 63 |
+
warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32)
|
| 64 |
+
|
| 65 |
+
# Test with upcast on A
|
| 66 |
+
add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT)
|
| 67 |
+
add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN)
|
| 68 |
+
|
| 69 |
+
# Test with upcast on B
|
| 70 |
+
add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT)
|
| 71 |
+
add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
if __name__ == '__main__':
|
| 75 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with S8 operands on SM80
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 80
|
| 49 |
+
dtype = cutlass_cppgen.DataType.s8
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmS8Sm80(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
|
| 62 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 63 |
+
class GemmS8Sm80StreamK(unittest.TestCase):
|
| 64 |
+
"""
|
| 65 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 66 |
+
"""
|
| 67 |
+
pass
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
|
| 71 |
+
|
| 72 |
+
# Tests using TensorOp
|
| 73 |
+
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
|
| 74 |
+
|
| 75 |
+
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
|
| 76 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3)
|
| 77 |
+
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
|
| 78 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
|
| 79 |
+
add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32,
|
| 80 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4)
|
| 81 |
+
|
| 82 |
+
# Tests using SIMT
|
| 83 |
+
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
|
| 84 |
+
|
| 85 |
+
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
|
| 86 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
| 87 |
+
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
|
| 88 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
|
| 89 |
+
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
|
| 90 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
|
| 91 |
+
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32,
|
| 92 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
|
| 93 |
+
add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32,
|
| 94 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
|
| 95 |
+
|
| 96 |
+
# Stream K tests
|
| 97 |
+
add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
|
| 98 |
+
add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
|
| 99 |
+
element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == '__main__':
|
| 103 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Low-level functionality tests for GEMM with S8 operands on SM90
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from functools import partial
|
| 38 |
+
import logging
|
| 39 |
+
import unittest
|
| 40 |
+
|
| 41 |
+
import cutlass_cppgen
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
|
| 44 |
+
from utils import LayoutCombination, add_test_gemm
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
cutlass_cppgen.set_log_level(logging.WARNING)
|
| 48 |
+
cc = 90
|
| 49 |
+
dtype = cutlass_cppgen.DataType.s8
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
|
| 53 |
+
@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
|
| 54 |
+
class GemmS8Sm90(unittest.TestCase):
|
| 55 |
+
"""
|
| 56 |
+
Wrapper class to which tests will be added dynamically in __main__
|
| 57 |
+
"""
|
| 58 |
+
pass
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc'])
|
| 62 |
+
|
| 63 |
+
add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
|
| 64 |
+
|
| 65 |
+
# Tests with 1x1x1 clusters
|
| 66 |
+
add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 67 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3)
|
| 68 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 69 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 70 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8,
|
| 71 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 72 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 73 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None)
|
| 74 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 75 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None)
|
| 76 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 77 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 78 |
+
|
| 79 |
+
# Tests with different cluster shapes
|
| 80 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 81 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 82 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 83 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
|
| 84 |
+
|
| 85 |
+
# Tests with warp-specialized ping-pong schedule
|
| 86 |
+
add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
|
| 87 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
|
| 88 |
+
kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
|
| 89 |
+
epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized)
|
| 90 |
+
|
| 91 |
+
# Tests for SIMT
|
| 92 |
+
add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
|
| 93 |
+
add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8,
|
| 94 |
+
element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
if __name__ == '__main__':
|
| 98 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from math import prod
|
| 34 |
+
import os
|
| 35 |
+
import re
|
| 36 |
+
import subprocess
|
| 37 |
+
|
| 38 |
+
import torch
|
| 39 |
+
|
| 40 |
+
from cutlass_library import (
|
| 41 |
+
DataType,
|
| 42 |
+
DataTypeSize,
|
| 43 |
+
GemmUniversalMode,
|
| 44 |
+
LayoutType,
|
| 45 |
+
OpcodeClass,
|
| 46 |
+
ShortDataTypeNames,
|
| 47 |
+
SwizzlingFunctor
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
from cutlass_cppgen.backend import compiler
|
| 51 |
+
from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
|
| 52 |
+
from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation
|
| 53 |
+
from cutlass_cppgen.shape import GemmCoord, MatrixCoord
|
| 54 |
+
from cutlass_cppgen.utils.datatypes import torch_type
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class GemmUniversalLauncher:
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
operation,
|
| 61 |
+
seed=2080,
|
| 62 |
+
verification=True,
|
| 63 |
+
iterations=500,
|
| 64 |
+
compiler_mode= "nvcc",
|
| 65 |
+
**kwargs,
|
| 66 |
+
) -> None:
|
| 67 |
+
self.math_operation = operation.tile_description.math_instruction.math_operation
|
| 68 |
+
self.verification = verification
|
| 69 |
+
|
| 70 |
+
if compiler_mode == "nvcc":
|
| 71 |
+
compiler.nvcc()
|
| 72 |
+
elif compiler_mode == "nvrtc":
|
| 73 |
+
compiler.nvrtc()
|
| 74 |
+
else:
|
| 75 |
+
raise Exception(f"Unexpected compiler string {compiler_mode}")
|
| 76 |
+
|
| 77 |
+
op_list = [operation]
|
| 78 |
+
if operation.arch < 90:
|
| 79 |
+
# Split K via Python is currently only supported for pre-SM90 kernels
|
| 80 |
+
self.reduction_operation: ReductionOperation = ReductionOperation(
|
| 81 |
+
shape=MatrixCoord(4, 32 * operation.C.alignment),
|
| 82 |
+
C=operation.C,
|
| 83 |
+
element_accumulator=operation.tile_description.math_instruction.element_accumulator,
|
| 84 |
+
element_compute=operation.epilogue_functor.element_epilogue,
|
| 85 |
+
epilogue_functor=operation.epilogue_functor,
|
| 86 |
+
count=operation.C.alignment,
|
| 87 |
+
)
|
| 88 |
+
op_list.append(self.reduction_operation)
|
| 89 |
+
|
| 90 |
+
compiler.add_module(op_list, bypass_cache=False)
|
| 91 |
+
|
| 92 |
+
self.operation = operation
|
| 93 |
+
|
| 94 |
+
self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element)
|
| 95 |
+
self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element)
|
| 96 |
+
self.dtype_C = torch_type(operation.C.element)
|
| 97 |
+
self.dtype_D = torch_type(operation.epilogue_functor.element_output)
|
| 98 |
+
|
| 99 |
+
element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
|
| 100 |
+
|
| 101 |
+
if element_size == 1:
|
| 102 |
+
self.rand_max = 1
|
| 103 |
+
self.rand_min = 0
|
| 104 |
+
elif element_size <= 8:
|
| 105 |
+
self.rand_max = 1
|
| 106 |
+
self.rand_min = -1
|
| 107 |
+
elif element_size == 16:
|
| 108 |
+
self.rand_max = 4
|
| 109 |
+
self.rand_min = -4
|
| 110 |
+
else:
|
| 111 |
+
self.rand_max = 8
|
| 112 |
+
self.rand_min = -8
|
| 113 |
+
|
| 114 |
+
self.seed = seed
|
| 115 |
+
|
| 116 |
+
self.compute_type = operation.epilogue_functor.element_epilogue
|
| 117 |
+
self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
|
| 118 |
+
|
| 119 |
+
def print_problem_size(self, p, mode, batch_count):
|
| 120 |
+
if mode == GemmUniversalMode.Gemm:
|
| 121 |
+
mode = "Gemm"
|
| 122 |
+
elif mode == GemmUniversalMode.Batched:
|
| 123 |
+
mode = "GemmBatched"
|
| 124 |
+
elif mode == GemmUniversalMode.GemmSplitKParallel:
|
| 125 |
+
mode = "GemmSplitKParallel"
|
| 126 |
+
print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}")
|
| 127 |
+
|
| 128 |
+
def uniform_init(self, shape, dtype, layout):
|
| 129 |
+
size = prod(shape)
|
| 130 |
+
if dtype.is_floating_point:
|
| 131 |
+
# Initialize data in FP32 and call convert to the data type we desire.
|
| 132 |
+
# This is a workaround for the following error that occurs when attempting to
|
| 133 |
+
# call uniform_ on a tensor with torch.float8_e4m3fn data:
|
| 134 |
+
# RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn'
|
| 135 |
+
data = torch.ceil(
|
| 136 |
+
torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_(
|
| 137 |
+
self.rand_min - 0.5, self.rand_max - 0.5)
|
| 138 |
+
).to(dtype)
|
| 139 |
+
else:
|
| 140 |
+
# PyTorch does not currently support integer-typed matrix multiplications on GPU.
|
| 141 |
+
# Fall back to CPU for integer type references.
|
| 142 |
+
data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1)
|
| 143 |
+
|
| 144 |
+
is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1)
|
| 145 |
+
|
| 146 |
+
if dtype == torch.float64 or dtype == torch.float32 or is_fp8:
|
| 147 |
+
data = data.to("cpu")
|
| 148 |
+
|
| 149 |
+
data_ref = data.reshape(shape)
|
| 150 |
+
|
| 151 |
+
if layout == LayoutType.RowMajor:
|
| 152 |
+
data_cutlass = data_ref
|
| 153 |
+
else:
|
| 154 |
+
data_cutlass = data_ref.transpose(-1, -2).contiguous()
|
| 155 |
+
|
| 156 |
+
data_cutlass = data_cutlass.to("cuda")
|
| 157 |
+
|
| 158 |
+
# As of this writing, few operations in PyTorch are supported with FP8 data.
|
| 159 |
+
# Thus, we perform computation in FP32 for FP8 reference checks.
|
| 160 |
+
if is_fp8:
|
| 161 |
+
data_ref = data_ref.to(torch.float32)
|
| 162 |
+
|
| 163 |
+
return data_cutlass, data_ref
|
| 164 |
+
|
| 165 |
+
def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
|
| 166 |
+
# If any tensor is on CPU, place all tensors on CPU unless only
|
| 167 |
+
# tensor C is on CPU
|
| 168 |
+
# Handle mixed-input cases by casting to the larger data type and overriding
|
| 169 |
+
# to whatever the data type of the larger type is
|
| 170 |
+
if self.dtype_A != self.dtype_B:
|
| 171 |
+
if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
|
| 172 |
+
tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
|
| 173 |
+
else:
|
| 174 |
+
tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)
|
| 175 |
+
|
| 176 |
+
devices = [x.device.type for x in [tensor_A, tensor_B]]
|
| 177 |
+
if tensor_C is not None:
|
| 178 |
+
devices.append(tensor_C.device.type)
|
| 179 |
+
|
| 180 |
+
if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
|
| 181 |
+
device = torch.device("cpu")
|
| 182 |
+
else:
|
| 183 |
+
device = tensor_A.device
|
| 184 |
+
|
| 185 |
+
tensor_A = tensor_A.to(device)
|
| 186 |
+
tensor_B = tensor_B.to(device)
|
| 187 |
+
if tensor_C is not None:
|
| 188 |
+
tensor_C = tensor_C.to(device)
|
| 189 |
+
|
| 190 |
+
dtype = torch_type(self.compute_type)
|
| 191 |
+
alpha_torch = torch.tensor([alpha], device=device).to(dtype)
|
| 192 |
+
beta_torch = torch.tensor([beta], device=device).to(dtype)
|
| 193 |
+
|
| 194 |
+
tmp = tensor_A @ tensor_B
|
| 195 |
+
tensor_D_ref = (alpha_torch * tmp)
|
| 196 |
+
if tensor_C is not None:
|
| 197 |
+
tensor_D_ref += (tensor_C * beta_torch)
|
| 198 |
+
return tensor_D_ref.to(self.dtype_D)
|
| 199 |
+
|
| 200 |
+
def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
|
| 201 |
+
torch.random.manual_seed(self.seed)
|
| 202 |
+
|
| 203 |
+
# Assign an actual batch count in cases where we are not running in batched mode.
|
| 204 |
+
# This is to differentiate between the number of split K slices and the batch count,
|
| 205 |
+
# which are overloaded within the single `batch_count` variable.
|
| 206 |
+
if mode == GemmUniversalMode.Batched:
|
| 207 |
+
true_batch_count = batch_count
|
| 208 |
+
else:
|
| 209 |
+
true_batch_count = 1
|
| 210 |
+
|
| 211 |
+
def transpose(layout):
|
| 212 |
+
if layout == LayoutType.RowMajor:
|
| 213 |
+
return LayoutType.ColumnMajor
|
| 214 |
+
else:
|
| 215 |
+
return LayoutType.RowMajor
|
| 216 |
+
|
| 217 |
+
tensor_A, tensor_A_ref = self.uniform_init(
|
| 218 |
+
(true_batch_count, problem_size.m, problem_size.k),
|
| 219 |
+
self.dtype_A,
|
| 220 |
+
self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout),
|
| 221 |
+
)
|
| 222 |
+
tensor_B, tensor_B_ref = self.uniform_init(
|
| 223 |
+
(true_batch_count, problem_size.k, problem_size.n),
|
| 224 |
+
self.dtype_B,
|
| 225 |
+
self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout),
|
| 226 |
+
)
|
| 227 |
+
if self.dtype_C is not None:
|
| 228 |
+
tensor_C, tensor_C_ref = self.uniform_init(
|
| 229 |
+
(true_batch_count, problem_size.m, problem_size.n),
|
| 230 |
+
self.dtype_C,
|
| 231 |
+
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
tensor_C = None
|
| 235 |
+
tensor_C_ref = None
|
| 236 |
+
|
| 237 |
+
tensor_D, _ = self.uniform_init(
|
| 238 |
+
(true_batch_count, problem_size.m, problem_size.n),
|
| 239 |
+
self.dtype_D,
|
| 240 |
+
self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
|
| 241 |
+
)
|
| 242 |
+
tensor_D = torch.zeros_like(tensor_D)
|
| 243 |
+
|
| 244 |
+
if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
|
| 245 |
+
alpha = int(alpha)
|
| 246 |
+
beta = int(beta)
|
| 247 |
+
|
| 248 |
+
#
|
| 249 |
+
# Launch kernel
|
| 250 |
+
#
|
| 251 |
+
|
| 252 |
+
arguments = GemmArguments(
|
| 253 |
+
operation=self.operation,
|
| 254 |
+
problem_size=problem_size,
|
| 255 |
+
A=tensor_A,
|
| 256 |
+
B=tensor_B,
|
| 257 |
+
C=tensor_C,
|
| 258 |
+
D=tensor_D,
|
| 259 |
+
output_op=self.operation.epilogue_type(alpha, beta),
|
| 260 |
+
gemm_mode=mode,
|
| 261 |
+
split_k_slices=split_k_slices,
|
| 262 |
+
batch=batch_count,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if mode == GemmUniversalMode.GemmSplitKParallel:
|
| 266 |
+
reduction_arguments = ReductionArguments(
|
| 267 |
+
self.reduction_operation,
|
| 268 |
+
problem_size=[problem_size.m, problem_size.n],
|
| 269 |
+
partitions=split_k_slices,
|
| 270 |
+
workspace=arguments.ptr_D,
|
| 271 |
+
destination=tensor_D,
|
| 272 |
+
source=tensor_C,
|
| 273 |
+
output_op=self.reduction_operation.epilogue_type(alpha, beta),
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
self.operation.run(arguments)
|
| 277 |
+
|
| 278 |
+
if mode == GemmUniversalMode.GemmSplitKParallel:
|
| 279 |
+
self.reduction_operation.run(reduction_arguments)
|
| 280 |
+
|
| 281 |
+
passed = True
|
| 282 |
+
|
| 283 |
+
if self.verification:
|
| 284 |
+
if mode == GemmUniversalMode.GemmSplitKParallel:
|
| 285 |
+
reduction_arguments.sync()
|
| 286 |
+
|
| 287 |
+
# Free memory allocated by args because we are not
|
| 288 |
+
# calling `arguments.sync()` in this case (which will free memory)
|
| 289 |
+
arguments.free()
|
| 290 |
+
else:
|
| 291 |
+
arguments.sync()
|
| 292 |
+
tensor_D_ref = self.reference(
|
| 293 |
+
problem_size,
|
| 294 |
+
tensor_A_ref,
|
| 295 |
+
tensor_B_ref,
|
| 296 |
+
tensor_C_ref,
|
| 297 |
+
alpha,
|
| 298 |
+
beta,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
tensor_D_ref = tensor_D_ref.to('cuda')
|
| 302 |
+
|
| 303 |
+
if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor:
|
| 304 |
+
tensor_D = tensor_D.transpose(-1, -2).contiguous()
|
| 305 |
+
|
| 306 |
+
passed = tensor_D.equal(tensor_D_ref)
|
| 307 |
+
|
| 308 |
+
try:
|
| 309 |
+
assert passed
|
| 310 |
+
except AssertionError:
|
| 311 |
+
self.print_problem_size(problem_size, mode, batch_count)
|
| 312 |
+
del arguments
|
| 313 |
+
if mode == GemmUniversalMode.GemmSplitKParallel:
|
| 314 |
+
del reduction_arguments
|
| 315 |
+
|
| 316 |
+
return passed
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
|
| 320 |
+
passed = True
|
| 321 |
+
|
| 322 |
+
minimum_operand_element_size = min(
|
| 323 |
+
DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]
|
| 324 |
+
)
|
| 325 |
+
opcode_class = operation.tile_description.math_instruction.opcode_class
|
| 326 |
+
|
| 327 |
+
if opcode_class == OpcodeClass.Simt:
|
| 328 |
+
alignment = 1
|
| 329 |
+
else:
|
| 330 |
+
alignment = 128 // minimum_operand_element_size
|
| 331 |
+
|
| 332 |
+
alignment_m = alignment
|
| 333 |
+
alignment_n = alignment
|
| 334 |
+
alignment_k = alignment
|
| 335 |
+
|
| 336 |
+
# INT8 alignment constraints
|
| 337 |
+
if opcode_class == OpcodeClass.Simt:
|
| 338 |
+
A_is_s8 = operation.A.element == DataType.s8
|
| 339 |
+
B_is_s8 = operation.B.element == DataType.s8
|
| 340 |
+
|
| 341 |
+
if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor:
|
| 342 |
+
alignment_m = 4
|
| 343 |
+
if B_is_s8 == DataType.s8 and operation.A.layout == LayoutType.RowMajor:
|
| 344 |
+
alignment_n = 4
|
| 345 |
+
if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor):
|
| 346 |
+
alignment_k = 4
|
| 347 |
+
|
| 348 |
+
threadblock_k = operation.tile_description.threadblock_shape[2]
|
| 349 |
+
|
| 350 |
+
assert testcase != "interleaved"
|
| 351 |
+
|
| 352 |
+
supports_split_k = operation.arch < 90 and not operation.swizzling_functor == SwizzlingFunctor.StreamK
|
| 353 |
+
|
| 354 |
+
if testcase == "multistage":
|
| 355 |
+
modes = [GemmUniversalMode.Gemm]
|
| 356 |
+
problem_size_m = [16, 528]
|
| 357 |
+
problem_size_n = [16, 528]
|
| 358 |
+
problem_size_k = [
|
| 359 |
+
threadblock_k,
|
| 360 |
+
threadblock_k * operation.tile_description.stages
|
| 361 |
+
+ operation.tile_description.math_instruction.instruction_shape[2],
|
| 362 |
+
]
|
| 363 |
+
problem_alpha = [1.0]
|
| 364 |
+
problem_beta = [0.0]
|
| 365 |
+
batch_counts = [1]
|
| 366 |
+
else:
|
| 367 |
+
modes = [GemmUniversalMode.Gemm]
|
| 368 |
+
batch_counts = [1, 2, 3, 5, 7]
|
| 369 |
+
if supports_split_k:
|
| 370 |
+
modes.append(GemmUniversalMode.GemmSplitKParallel)
|
| 371 |
+
|
| 372 |
+
problem_size_m = [alignment_m, 512 - 3 * alignment_m]
|
| 373 |
+
problem_size_n = [alignment_n, 512 - 2 * alignment_n]
|
| 374 |
+
if operation.tile_description.stages is None:
|
| 375 |
+
stages_for_k_calc = 7
|
| 376 |
+
else:
|
| 377 |
+
stages_for_k_calc = operation.tile_description.stages
|
| 378 |
+
problem_size_k = [
|
| 379 |
+
alignment_k,
|
| 380 |
+
threadblock_k * stages_for_k_calc - alignment_k,
|
| 381 |
+
threadblock_k * stages_for_k_calc * 3 - alignment_k,
|
| 382 |
+
]
|
| 383 |
+
problem_alpha = [1.0]
|
| 384 |
+
problem_beta = [2.0]
|
| 385 |
+
|
| 386 |
+
testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode)
|
| 387 |
+
|
| 388 |
+
for mode in modes:
|
| 389 |
+
for m in problem_size_m:
|
| 390 |
+
for n in problem_size_n:
|
| 391 |
+
for k in problem_size_k:
|
| 392 |
+
for batch_count in batch_counts:
|
| 393 |
+
for alpha in problem_alpha:
|
| 394 |
+
for beta in problem_beta:
|
| 395 |
+
# skip very small K problems
|
| 396 |
+
if testcase == "universal":
|
| 397 |
+
if k // batch_count < 2 * threadblock_k:
|
| 398 |
+
continue
|
| 399 |
+
|
| 400 |
+
problem_size = GemmCoord(m, n, k)
|
| 401 |
+
|
| 402 |
+
if supports_split_k:
|
| 403 |
+
split_k_slices = batch_count
|
| 404 |
+
else:
|
| 405 |
+
split_k_slices = 1
|
| 406 |
+
|
| 407 |
+
overridden_mode = mode
|
| 408 |
+
if mode == GemmUniversalMode.Gemm and batch_count > 1:
|
| 409 |
+
overridden_mode = GemmUniversalMode.Batched
|
| 410 |
+
|
| 411 |
+
passed = testbed.run(
|
| 412 |
+
overridden_mode,
|
| 413 |
+
problem_size,
|
| 414 |
+
batch_count,
|
| 415 |
+
split_k_slices,
|
| 416 |
+
alpha,
|
| 417 |
+
beta,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
if not passed:
|
| 421 |
+
return False
|
| 422 |
+
|
| 423 |
+
return passed
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
import pathlib
|
| 34 |
+
import unittest
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
if __name__ == '__main__':
|
| 38 |
+
loader = unittest.TestLoader()
|
| 39 |
+
script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
|
| 40 |
+
tests = loader.discover(script_dir, 'gemm_*.py')
|
| 41 |
+
testRunner = unittest.runner.TextTestRunner()
|
| 42 |
+
results = testRunner.run(tests)
|
| 43 |
+
if not results.wasSuccessful():
|
| 44 |
+
raise Exception('Test cases failed')
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
from cutlass_library import SubstituteTemplate
|
| 34 |
+
|
| 35 |
+
import cutlass_cppgen
|
| 36 |
+
from cutlass_library import (
|
| 37 |
+
DataTypeNames,
|
| 38 |
+
EpilogueScheduleSuffixes,
|
| 39 |
+
KernelScheduleSuffixes,
|
| 40 |
+
LayoutType,
|
| 41 |
+
OpcodeClassNames,
|
| 42 |
+
ShortDataTypeNames,
|
| 43 |
+
ShortLayoutTypeNames
|
| 44 |
+
)
|
| 45 |
+
from cutlass_cppgen.backend import library
|
| 46 |
+
|
| 47 |
+
from gemm_testbed import test_all_gemm
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class Layout:
|
| 51 |
+
"""
|
| 52 |
+
Utility class to map transpose and non-transpose terminology to row- and column-major terminology
|
| 53 |
+
"""
|
| 54 |
+
|
| 55 |
+
T = LayoutType.RowMajor
|
| 56 |
+
N = LayoutType.ColumnMajor
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class LayoutCombination:
|
| 60 |
+
"""
|
| 61 |
+
Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
NNN = (Layout.N, Layout.N, Layout.N)
|
| 65 |
+
NNT = (Layout.N, Layout.N, Layout.T)
|
| 66 |
+
NTN = (Layout.N, Layout.T, Layout.N)
|
| 67 |
+
NTT = (Layout.N, Layout.T, Layout.T)
|
| 68 |
+
TNN = (Layout.T, Layout.N, Layout.N)
|
| 69 |
+
TNT = (Layout.T, Layout.N, Layout.T)
|
| 70 |
+
TTN = (Layout.T, Layout.T, Layout.N)
|
| 71 |
+
TTT = (Layout.T, Layout.T, Layout.T)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_name(
|
| 75 |
+
layouts,
|
| 76 |
+
alignments,
|
| 77 |
+
element_output,
|
| 78 |
+
element_accumulator,
|
| 79 |
+
element_epilogue,
|
| 80 |
+
cluster_shape,
|
| 81 |
+
threadblock_shape,
|
| 82 |
+
stages,
|
| 83 |
+
element_a,
|
| 84 |
+
element_b,
|
| 85 |
+
element_c,
|
| 86 |
+
arch,
|
| 87 |
+
opclass,
|
| 88 |
+
kernel_schedule=None,
|
| 89 |
+
epilogue_schedule=None,
|
| 90 |
+
suffix="",
|
| 91 |
+
):
|
| 92 |
+
"""
|
| 93 |
+
Generates a procedural name for a test case.
|
| 94 |
+
|
| 95 |
+
:param layouts: indexable container of layouts of A, B, and C operands
|
| 96 |
+
:param alignments: indexable container of alignments of A, B, and C operands
|
| 97 |
+
:param element_output: data type of the output element
|
| 98 |
+
:param element_accumulator: data type used in accumulation
|
| 99 |
+
:param element_epilogue: data type used in computing the epilogue
|
| 100 |
+
:param cluster_shape: indexable container of dimensions of threadblock cluster to be launched
|
| 101 |
+
:param threadblock_shape: indexable container of dimensions of threadblock tiles
|
| 102 |
+
:param stages: number of pipeline stages to use in the kernel
|
| 103 |
+
:type stages: int
|
| 104 |
+
:param element_a: data type of operand A
|
| 105 |
+
:param element_b: data type of operand B
|
| 106 |
+
:param element_c: data type of operand C
|
| 107 |
+
:param arch: compute capability of kernel being generated
|
| 108 |
+
:type arch: int
|
| 109 |
+
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
|
| 110 |
+
:type opclass: cutlass_cppgen.OpcodeClass
|
| 111 |
+
:param kernel_schedule: kernel_schedule type
|
| 112 |
+
:type kernel_schedule: cutlass_cppgen.KernelScheduleType
|
| 113 |
+
:param epilogue_schedule: epilogue_schedule type
|
| 114 |
+
:type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
|
| 115 |
+
:param suffix: additional string to add to the suffix of the name
|
| 116 |
+
:type suffix: str
|
| 117 |
+
|
| 118 |
+
:return: str
|
| 119 |
+
"""
|
| 120 |
+
name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"
|
| 121 |
+
return SubstituteTemplate(
|
| 122 |
+
name_format,
|
| 123 |
+
{
|
| 124 |
+
"arch": str(arch),
|
| 125 |
+
"eA": DataTypeNames[element_a],
|
| 126 |
+
"eB": DataTypeNames[element_b],
|
| 127 |
+
"eC": DataTypeNames[element_c],
|
| 128 |
+
"lA": ShortLayoutTypeNames[layouts[0]],
|
| 129 |
+
"lB": ShortLayoutTypeNames[layouts[1]],
|
| 130 |
+
"lC": ShortLayoutTypeNames[layouts[2]],
|
| 131 |
+
"opclass": OpcodeClassNames[opclass],
|
| 132 |
+
"acc": DataTypeNames[element_accumulator],
|
| 133 |
+
"cM": str(cluster_shape[0]),
|
| 134 |
+
"cN": str(cluster_shape[1]),
|
| 135 |
+
"cK": str(cluster_shape[2]),
|
| 136 |
+
"tbM": str(threadblock_shape[0]),
|
| 137 |
+
"tbN": str(threadblock_shape[1]),
|
| 138 |
+
"tbK": str(threadblock_shape[2]),
|
| 139 |
+
"stages": str(stages) if stages is not None else "auto",
|
| 140 |
+
"aA": str(alignments[0]),
|
| 141 |
+
"aB": str(alignments[1]),
|
| 142 |
+
"aC": str(alignments[2]),
|
| 143 |
+
"k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule],
|
| 144 |
+
"e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule],
|
| 145 |
+
"suffix": "" if suffix is None else suffix,
|
| 146 |
+
},
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def add_test_gemm(
|
| 151 |
+
cls=None,
|
| 152 |
+
cc=None,
|
| 153 |
+
element=None,
|
| 154 |
+
layouts=None,
|
| 155 |
+
alignments=None,
|
| 156 |
+
element_output=None,
|
| 157 |
+
element_accumulator=None,
|
| 158 |
+
cluster_shape=None,
|
| 159 |
+
threadblock_shape=None,
|
| 160 |
+
warp_count=None,
|
| 161 |
+
stages=None,
|
| 162 |
+
opclass=None,
|
| 163 |
+
swizzle=None,
|
| 164 |
+
kernel_schedule=None,
|
| 165 |
+
epilogue_schedule=None,
|
| 166 |
+
compilation_modes=['nvcc', 'nvrtc'],
|
| 167 |
+
element_A=None,
|
| 168 |
+
element_B=None,
|
| 169 |
+
element_C=None):
|
| 170 |
+
"""
|
| 171 |
+
Create test-running functions with the given specification and set it as a method of ``cls``.
|
| 172 |
+
|
| 173 |
+
:param cls: class to which the generated method will be added
|
| 174 |
+
:type cls: type
|
| 175 |
+
:param cc: compute capability to compile for
|
| 176 |
+
:type cc: int
|
| 177 |
+
:param element: data type of A and B operands
|
| 178 |
+
:type element: cutlass_cppgen.DataType.f16
|
| 179 |
+
:param layouts: layouts of A, B, and C operands
|
| 180 |
+
:type layouts: list or tuple
|
| 181 |
+
:param alignments: alingments of A, B, and C operands
|
| 182 |
+
:type alignments: list or tuple
|
| 183 |
+
:param element_output: data type of the output element
|
| 184 |
+
:type element_output: cutlass_cppgen.DataType
|
| 185 |
+
:param element_accumulator: data type used in accumulation
|
| 186 |
+
:type element_accumulator: cutlass_cppgen.DataType
|
| 187 |
+
:param cluster_shape: dimensions of clusters
|
| 188 |
+
:type cluster_shape: list or tuple
|
| 189 |
+
:param threadblock_shape: dimensions of threadblock tiles
|
| 190 |
+
:type threadblock_shape: list or tuple
|
| 191 |
+
:param warp_count: warps to be launched per threadblock dimension
|
| 192 |
+
:type warp_count: list or tuple
|
| 193 |
+
:param stages: number of pipeline stages to use in the kernel
|
| 194 |
+
:type stages: int
|
| 195 |
+
:param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
|
| 196 |
+
:type opclass: cutlass_cppgen.OpcodeClass
|
| 197 |
+
:param swizzle: threadblock swizzling functor
|
| 198 |
+
:param kernel_schedule: kernel schedule to use
|
| 199 |
+
:type kernel_schedule: cutlass_cppgen.KernelScheduleType
|
| 200 |
+
:param epilogue_schedule: epilogue schedule to use
|
| 201 |
+
:type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
|
| 202 |
+
:param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc')
|
| 203 |
+
:type compilation_modes: list,
|
| 204 |
+
:param element_A: data type of operand A. If set, overrides ``element``
|
| 205 |
+
:type element_A: cutlass_cppgen.DataType
|
| 206 |
+
:param element_B: data type of operand B. If set, overrides ``element``
|
| 207 |
+
:type element_B: cutlass_cppgen.DataType
|
| 208 |
+
:param element_C: data type of operand C. If set, overrides ``element``
|
| 209 |
+
:type element_C: cutlass_cppgen.DataType
|
| 210 |
+
"""
|
| 211 |
+
|
| 212 |
+
if element_A is None:
|
| 213 |
+
element_A = element
|
| 214 |
+
if element_B is None:
|
| 215 |
+
element_B = element
|
| 216 |
+
if element_C is None:
|
| 217 |
+
element_C = element
|
| 218 |
+
if element_output is None:
|
| 219 |
+
element_output = element
|
| 220 |
+
if element_accumulator is None:
|
| 221 |
+
element_accumulator = element
|
| 222 |
+
|
| 223 |
+
for compilation_mode in compilation_modes:
|
| 224 |
+
def run(self):
|
| 225 |
+
"""
|
| 226 |
+
Dynamically-generated function that constructs a GEMM operation and verifies it against
|
| 227 |
+
multiple test cases.
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
layout_A, layout_B, layout_C = layouts
|
| 231 |
+
alignment_A, alignment_B, alignment_C = alignments
|
| 232 |
+
|
| 233 |
+
plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B,
|
| 234 |
+
element_C=element_C, element_D=element_output,
|
| 235 |
+
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
|
| 236 |
+
element_accumulator=element_accumulator,
|
| 237 |
+
kernel_cc=cc)
|
| 238 |
+
|
| 239 |
+
plan.opclass = opclass
|
| 240 |
+
if swizzle is not None:
|
| 241 |
+
plan.swizzling_functor = swizzle
|
| 242 |
+
|
| 243 |
+
td = plan.tile_descriptions()[0]
|
| 244 |
+
|
| 245 |
+
if warp_count is not None:
|
| 246 |
+
td.warp_count = warp_count
|
| 247 |
+
td.threadblock_shape = threadblock_shape
|
| 248 |
+
td.stages = stages
|
| 249 |
+
td.cluster_shape = cluster_shape
|
| 250 |
+
op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
|
| 251 |
+
self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))
|
| 252 |
+
|
| 253 |
+
element_epilogue = element_accumulator
|
| 254 |
+
name = get_name(
|
| 255 |
+
layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
|
| 256 |
+
element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
|
| 257 |
+
stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass,
|
| 258 |
+
kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')
|
| 259 |
+
|
| 260 |
+
setattr(cls, name, run)
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Tests for a successful installation of the CUTLASS Python interface
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import os
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
import cutlass_library
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class InstallationTest(unittest.TestCase):
|
| 45 |
+
def test_cutlass_source_paths(self):
|
| 46 |
+
"""
|
| 47 |
+
Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages
|
| 48 |
+
"""
|
| 49 |
+
src_file = 'include/cutlass/cutlass.h'
|
| 50 |
+
library_file = os.path.join(cutlass_library.source_path, src_file)
|
| 51 |
+
cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file)
|
| 52 |
+
assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded."
|
| 53 |
+
assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded."
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
if __name__ == "__main__":
|
| 57 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Tests the high-level Conv2d interface
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from math import ceil
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
import cutlass_cppgen.utils.datatypes as datatypes
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
from utils import ExpectException
|
| 44 |
+
import os
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Conv2dEquivalence:
|
| 48 |
+
"""
|
| 49 |
+
Helper class for testing the equivalence of different constructions of the Conv2d interface
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator,
|
| 52 |
+
alignment_A, alignment_B, alignment_C):
|
| 53 |
+
|
| 54 |
+
self.element_A = element_A
|
| 55 |
+
self.element_B = element_B
|
| 56 |
+
self.element_C = element_C
|
| 57 |
+
self.element_D = element_D
|
| 58 |
+
self.element_accumulator = element_accumulator
|
| 59 |
+
self.alignment_A = alignment_A
|
| 60 |
+
self.alignment_B = alignment_B
|
| 61 |
+
self.alignment_C = alignment_C
|
| 62 |
+
|
| 63 |
+
self.conv_kind = conv_kind
|
| 64 |
+
|
| 65 |
+
self.plan = cutlass_cppgen.op.Conv2d(
|
| 66 |
+
kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C,
|
| 67 |
+
element_D=element_D, element_accumulator=element_accumulator)
|
| 68 |
+
|
| 69 |
+
self.op = self.plan.construct(
|
| 70 |
+
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
|
| 71 |
+
alignment_C=self.alignment_C)
|
| 72 |
+
|
| 73 |
+
def _plans_equal(self, other_plan) -> bool:
|
| 74 |
+
"""
|
| 75 |
+
Compares whether two plans are equal
|
| 76 |
+
|
| 77 |
+
:param other_plan: plan to compare against the default Conv2d
|
| 78 |
+
:type other_plan: cutlass_cppgen.op.Conv2d
|
| 79 |
+
|
| 80 |
+
:return: whether `other_plan` is equivalent to `self.plan`
|
| 81 |
+
:rtype: bool
|
| 82 |
+
"""
|
| 83 |
+
other_op = other_plan.construct(
|
| 84 |
+
alignment_A=self.alignment_A, alignment_B=self.alignment_B,
|
| 85 |
+
alignment_C=self.alignment_C)
|
| 86 |
+
|
| 87 |
+
return self.op.rt_module.emit() == other_op.rt_module.emit()
|
| 88 |
+
|
| 89 |
+
def generic_test(self):
|
| 90 |
+
"""
|
| 91 |
+
Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types
|
| 92 |
+
and layouts for constructing the Conv2d interface
|
| 93 |
+
"""
|
| 94 |
+
if not datatypes.is_numpy_available():
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
# Test when specifying all parameters
|
| 98 |
+
plan_other = cutlass_cppgen.op.Conv2d(
|
| 99 |
+
kind=self.conv_kind,
|
| 100 |
+
element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
|
| 101 |
+
element_D=self.element_D, element_accumulator=self.element_accumulator)
|
| 102 |
+
assert self._plans_equal(plan_other)
|
| 103 |
+
|
| 104 |
+
# Test when specifying all parameters but A
|
| 105 |
+
plan_other = cutlass_cppgen.op.Conv2d(
|
| 106 |
+
kind=self.conv_kind,
|
| 107 |
+
element_B=self.element_B, element_C=self.element_C,
|
| 108 |
+
element_D=self.element_D, element_accumulator=self.element_accumulator,
|
| 109 |
+
element=self.element_A)
|
| 110 |
+
assert self._plans_equal(plan_other)
|
| 111 |
+
|
| 112 |
+
# Test when specifying all parameters but A and B as tensors using generic element and output
|
| 113 |
+
plan_other = cutlass_cppgen.op.Conv2d(
|
| 114 |
+
kind=self.conv_kind,
|
| 115 |
+
element_C=self.element_C,
|
| 116 |
+
element_D=self.element_D, element_accumulator=self.element_accumulator,
|
| 117 |
+
element=self.element_A)
|
| 118 |
+
assert self._plans_equal(plan_other)
|
| 119 |
+
|
| 120 |
+
# Test without explicit accumulator. Only run if the type of C and the accumulator are equal
|
| 121 |
+
if self.element_C == self.element_accumulator:
|
| 122 |
+
plan_other = cutlass_cppgen.op.Conv2d(
|
| 123 |
+
kind=self.conv_kind,
|
| 124 |
+
element_C=self.element_C,
|
| 125 |
+
element_D=self.element_D,
|
| 126 |
+
element=self.element_A)
|
| 127 |
+
assert self._plans_equal(plan_other)
|
| 128 |
+
|
| 129 |
+
# Test with only the generic types. Only rune if the types of A, B, C, and D are the same
|
| 130 |
+
if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
|
| 131 |
+
and self.element_A == self.element_accumulator):
|
| 132 |
+
plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A)
|
| 133 |
+
assert self._plans_equal(plan_other)
|
| 134 |
+
|
| 135 |
+
def numpy_test(self):
|
| 136 |
+
"""
|
| 137 |
+
Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend
|
| 138 |
+
"""
|
| 139 |
+
if not datatypes.is_numpy_available():
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
import numpy as np
|
| 143 |
+
type_A = datatypes.numpy_type(self.element_A)
|
| 144 |
+
type_B = datatypes.numpy_type(self.element_B)
|
| 145 |
+
type_C = datatypes.numpy_type(self.element_C)
|
| 146 |
+
type_D = datatypes.numpy_type(self.element_D)
|
| 147 |
+
type_accum = datatypes.numpy_type(self.element_accumulator)
|
| 148 |
+
|
| 149 |
+
size = (2, 2)
|
| 150 |
+
A = np.zeros(size, dtype=type_A)
|
| 151 |
+
B = np.zeros(size, dtype=type_B)
|
| 152 |
+
C = np.zeros(size, dtype=type_C)
|
| 153 |
+
D = np.zeros(size, dtype=type_D)
|
| 154 |
+
|
| 155 |
+
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
|
| 156 |
+
|
| 157 |
+
def torch_test(self):
|
| 158 |
+
"""
|
| 159 |
+
Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend
|
| 160 |
+
"""
|
| 161 |
+
if not datatypes.is_torch_available():
|
| 162 |
+
return
|
| 163 |
+
|
| 164 |
+
import torch
|
| 165 |
+
type_A = datatypes.torch_type(self.element_A)
|
| 166 |
+
type_B = datatypes.torch_type(self.element_B)
|
| 167 |
+
type_C = datatypes.torch_type(self.element_C)
|
| 168 |
+
type_D = datatypes.torch_type(self.element_D)
|
| 169 |
+
type_accum = datatypes.torch_type(self.element_accumulator)
|
| 170 |
+
|
| 171 |
+
size = (2, 2)
|
| 172 |
+
|
| 173 |
+
A = torch.empty(size, dtype=type_A)
|
| 174 |
+
B = torch.empty(size, dtype=type_B)
|
| 175 |
+
C = torch.empty(size, dtype=type_C)
|
| 176 |
+
D = torch.empty(size, dtype=type_D)
|
| 177 |
+
|
| 178 |
+
return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
|
| 179 |
+
|
| 180 |
+
def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D):
|
| 181 |
+
# Test when specifying all parameters via tensors
|
| 182 |
+
plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum)
|
| 183 |
+
assert self._plans_equal(plan_np)
|
| 184 |
+
|
| 185 |
+
# Test when specifying all parameters but A as tensors
|
| 186 |
+
plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A)
|
| 187 |
+
assert self._plans_equal(plan_np)
|
| 188 |
+
|
| 189 |
+
# Test when specifying all parameters but A and B as tensors and using generic element and output
|
| 190 |
+
if type_A == type_B:
|
| 191 |
+
plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A)
|
| 192 |
+
assert self._plans_equal(plan_np)
|
| 193 |
+
|
| 194 |
+
# Test without explicit accumulator. Only run if the type of C and the accumulator.
|
| 195 |
+
if type_C == type_accum:
|
| 196 |
+
plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D)
|
| 197 |
+
assert self._plans_equal(plan_np)
|
| 198 |
+
|
| 199 |
+
# Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
|
| 200 |
+
if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum):
|
| 201 |
+
plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A)
|
| 202 |
+
assert self._plans_equal(plan_np)
|
| 203 |
+
|
| 204 |
+
def test_all(self):
|
| 205 |
+
"""
|
| 206 |
+
Runs all tests on the Gemm interface
|
| 207 |
+
"""
|
| 208 |
+
self.generic_test()
|
| 209 |
+
self.numpy_test()
|
| 210 |
+
self.torch_test()
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.')
|
| 214 |
+
class ConvEquivalenceTest(unittest.TestCase):
|
| 215 |
+
"""
|
| 216 |
+
Tests the equivalence of different constructions of the Conv2d interface
|
| 217 |
+
"""
|
| 218 |
+
pass
|
| 219 |
+
|
| 220 |
+
type2alignment = {
|
| 221 |
+
cutlass_cppgen.DataType.f16: 8,
|
| 222 |
+
cutlass_cppgen.DataType.f32: 4
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator):
|
| 226 |
+
|
| 227 |
+
test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}"
|
| 228 |
+
|
| 229 |
+
def run(self):
|
| 230 |
+
conv2d_eq = Conv2dEquivalence(
|
| 231 |
+
conv_kind=conv_kind,
|
| 232 |
+
element_A=element_A, element_B=element_B,
|
| 233 |
+
element_C=element_C, element_D=element_D,
|
| 234 |
+
element_accumulator=element_accumulator,
|
| 235 |
+
alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B],
|
| 236 |
+
alignment_C=type2alignment[element_C]
|
| 237 |
+
)
|
| 238 |
+
conv2d_eq.test_all()
|
| 239 |
+
|
| 240 |
+
setattr(ConvEquivalenceTest, test_name, run)
|
| 241 |
+
|
| 242 |
+
for conv_kind in ["fprop", "wgrad", "dgrad"]:
|
| 243 |
+
for types in [
|
| 244 |
+
[cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16],
|
| 245 |
+
[cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32],
|
| 246 |
+
[cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16],
|
| 247 |
+
[cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32],
|
| 248 |
+
[cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32]
|
| 249 |
+
]:
|
| 250 |
+
add_test(conv_kind, types[0], types[1], types[2], types[3], types[4])
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.')
|
| 254 |
+
class Conv2dErrorTests(unittest.TestCase):
|
| 255 |
+
"""
|
| 256 |
+
Tests various error scenarios that arise with the high-level Gemm interface
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
def test_alignment(self):
|
| 260 |
+
"""
|
| 261 |
+
Tests case in which the alignment specified is unsupported
|
| 262 |
+
"""
|
| 263 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16)
|
| 264 |
+
|
| 265 |
+
with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'):
|
| 266 |
+
op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3)
|
| 267 |
+
|
| 268 |
+
def test_invalid_tile_description(self):
|
| 269 |
+
"""
|
| 270 |
+
Tests scenarios in which an invalid tile description is provided for a given CC
|
| 271 |
+
"""
|
| 272 |
+
plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16)
|
| 273 |
+
|
| 274 |
+
td = plan.tile_descriptions()[0]
|
| 275 |
+
td.threadblock_shape=[17, 32, 5]
|
| 276 |
+
|
| 277 |
+
plan.tile_description = td
|
| 278 |
+
with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'):
|
| 279 |
+
plan.compile()
|
| 280 |
+
# Clean up the error message
|
| 281 |
+
os.remove("./cutlass_python_compilation_device_error.txt")
|
| 282 |
+
|
| 283 |
+
if __name__ == '__main__':
|
| 284 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Test the EVT interface
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
from cutlass_cppgen import LayoutType, Tensor
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
from cutlass_cppgen.epilogue import reshape, permute
|
| 44 |
+
|
| 45 |
+
from utils import ExpectException
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
|
| 49 |
+
class EVTErrorTests(unittest.TestCase):
|
| 50 |
+
"""
|
| 51 |
+
Tests various error scenarios that arise with the EVT interface
|
| 52 |
+
"""
|
| 53 |
+
@unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'")
|
| 54 |
+
def test_root_not_d(self):
|
| 55 |
+
"""
|
| 56 |
+
Test when "D" does not exist in Sm90 EVT
|
| 57 |
+
"""
|
| 58 |
+
def evt_root_not_d(accum, alpha):
|
| 59 |
+
F = accum * alpha
|
| 60 |
+
return F
|
| 61 |
+
|
| 62 |
+
example_tensors = {
|
| 63 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 64 |
+
"alpha": 1.2,
|
| 65 |
+
"F": self.fake_tensor(np.float16, (6, 512, 512))
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
with ExpectException(device_cc() == 90,
|
| 69 |
+
"SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, "
|
| 70 |
+
"but the variable 'D' is not found in the return values.", True):
|
| 71 |
+
|
| 72 |
+
cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors)
|
| 73 |
+
|
| 74 |
+
def test_no_accum(self):
|
| 75 |
+
"""
|
| 76 |
+
Test when "accum" is not in input arguments
|
| 77 |
+
"""
|
| 78 |
+
def evt_no_accum(alpha, C):
|
| 79 |
+
D = alpha * C
|
| 80 |
+
return D
|
| 81 |
+
|
| 82 |
+
example_tensors = {
|
| 83 |
+
"C": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 84 |
+
"alpha": 1.2,
|
| 85 |
+
"D": self.fake_tensor(np.float16, (6, 512, 512))
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True):
|
| 89 |
+
cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors)
|
| 90 |
+
|
| 91 |
+
@unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size")
|
| 92 |
+
def test_too_much_shared_memory(self):
|
| 93 |
+
"""
|
| 94 |
+
Test when the epilogue consumes too much shared memory
|
| 95 |
+
"""
|
| 96 |
+
def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8):
|
| 97 |
+
D1 = accum + C1
|
| 98 |
+
D2 = D1 + C2
|
| 99 |
+
D3 = D2 + C3
|
| 100 |
+
D4 = D3 + C4
|
| 101 |
+
D5 = D4 + C5
|
| 102 |
+
D6 = D5 + C6
|
| 103 |
+
D7 = D6 + C7
|
| 104 |
+
D = D7 + C8
|
| 105 |
+
return D, D1, D2, D3, D4, D5, D6, D7
|
| 106 |
+
|
| 107 |
+
example_tensors = {
|
| 108 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 109 |
+
"C1": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 110 |
+
"C2": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 111 |
+
"C3": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 112 |
+
"C4": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 113 |
+
"C5": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 114 |
+
"C6": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 115 |
+
"C7": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 116 |
+
"C8": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 117 |
+
"D1": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 118 |
+
"D2": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 119 |
+
"D3": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 120 |
+
"D4": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 121 |
+
"D5": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 122 |
+
"D6": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 123 |
+
"D7": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 124 |
+
"D": self.fake_tensor(np.float16, (6, 512, 512))
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors)
|
| 128 |
+
|
| 129 |
+
plan = cutlass_cppgen.op.Gemm(
|
| 130 |
+
element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor,
|
| 131 |
+
element_accumulator=np.float32
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
with ExpectException(True,
|
| 135 |
+
"RuntimeError: The epilogue consumes too much shared memory. "
|
| 136 |
+
"No valid tile description is found in the generator.", True):
|
| 137 |
+
plan.epilogue_visitor = epilogue_visitor
|
| 138 |
+
|
| 139 |
+
def test_not_ssa(self):
|
| 140 |
+
"""
|
| 141 |
+
Test when the epilogue is not in SSA
|
| 142 |
+
"""
|
| 143 |
+
def evt_redefine(accum, C, alpha):
|
| 144 |
+
F = accum + C
|
| 145 |
+
F = F * alpha
|
| 146 |
+
D = F
|
| 147 |
+
return D, F
|
| 148 |
+
|
| 149 |
+
example_tensors = {
|
| 150 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 151 |
+
"C": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 152 |
+
"alpha": 1.5,
|
| 153 |
+
"D": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 154 |
+
"F": self.fake_tensor(np.float16, (6, 512, 512))
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True):
|
| 158 |
+
cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors)
|
| 159 |
+
|
| 160 |
+
def evt_undefine(accum, alpha):
|
| 161 |
+
F = accum + C
|
| 162 |
+
D = F * alpha
|
| 163 |
+
return D, F
|
| 164 |
+
|
| 165 |
+
example_tensors = {
|
| 166 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 167 |
+
"alpha": 1.5,
|
| 168 |
+
"D": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 169 |
+
"F": self.fake_tensor(np.float16, (6, 512, 512))
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True):
|
| 173 |
+
cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors)
|
| 174 |
+
|
| 175 |
+
def test_missing_example_tensor(self):
|
| 176 |
+
"""
|
| 177 |
+
Test when the example tensor of an input/output variable is not provided
|
| 178 |
+
"""
|
| 179 |
+
def evt_missing_example_tensor(accum, C):
|
| 180 |
+
D = accum + C
|
| 181 |
+
return D
|
| 182 |
+
|
| 183 |
+
example_tensors = {
|
| 184 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 185 |
+
"C": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
with ExpectException(True, "RuntimeError: Example input for D is not provided.", True):
|
| 189 |
+
cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors)
|
| 190 |
+
|
| 191 |
+
example_tensors = {
|
| 192 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 193 |
+
"D": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
with ExpectException(True, "RuntimeError: Example input for C is not provided.", True):
|
| 197 |
+
cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors)
|
| 198 |
+
|
| 199 |
+
def test_return_expression(self):
|
| 200 |
+
"""
|
| 201 |
+
Test when the return value is an expression
|
| 202 |
+
"""
|
| 203 |
+
def evt_return_expr(accum, C):
|
| 204 |
+
return accum + C
|
| 205 |
+
|
| 206 |
+
example_tensors = {
|
| 207 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 208 |
+
"C": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
with ExpectException(True, "SyntaxError: Return value cannot be an expression", True):
|
| 212 |
+
cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors)
|
| 213 |
+
|
| 214 |
+
def test_incompatible_shape(self):
|
| 215 |
+
"""
|
| 216 |
+
Test when the shape of example tensors are incompatible
|
| 217 |
+
"""
|
| 218 |
+
def evt_incompatible_shape(accum, C):
|
| 219 |
+
D = accum + C
|
| 220 |
+
return D
|
| 221 |
+
|
| 222 |
+
example_tensors = {
|
| 223 |
+
"accum": self.fake_tensor(np.float16, (6, 256, 512)),
|
| 224 |
+
"C": self.fake_tensor(np.float16, (6, 512, 512)),
|
| 225 |
+
"D": self.fake_tensor(np.float16, (6, 512, 512))
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
with ExpectException(True,
|
| 229 |
+
"RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True):
|
| 230 |
+
cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors)
|
| 231 |
+
|
| 232 |
+
def test_no_matching_impl(self):
|
| 233 |
+
def evt_no_matching_impl(accum, bias):
|
| 234 |
+
D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1))
|
| 235 |
+
return D
|
| 236 |
+
|
| 237 |
+
example_tensors = {
|
| 238 |
+
"accum": self.fake_tensor(np.float16, (6, 512, 256)),
|
| 239 |
+
"bias": self.fake_tensor(np.float16, (16, 32)),
|
| 240 |
+
"D": self.fake_tensor(np.float16, (6, 512, 256))
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True):
|
| 244 |
+
cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors)
|
| 245 |
+
#
|
| 246 |
+
# Helper functions
|
| 247 |
+
#
|
| 248 |
+
|
| 249 |
+
def fake_tensor(self, element, shape):
|
| 250 |
+
return Tensor(element=element, shape=shape, layout_tag=LayoutType.RowMajor)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == '__main__':
|
| 254 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Tests the high-level GEMM interface
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from math import ceil
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
import cutlass_cppgen
|
| 41 |
+
import cutlass_cppgen.utils.datatypes as datatypes
|
| 42 |
+
from cutlass_cppgen.backend.utils.device import device_cc
|
| 43 |
+
from utils import ExpectException
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class GemmEquivalence:
|
| 47 |
+
"""
|
| 48 |
+
Helper class for testing the equivalence of different constructions of the Gemm interface
|
| 49 |
+
"""
|
| 50 |
+
def __init__(self, element_A, element_B, element_C, element_D, element_accumulator,
|
| 51 |
+
layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C):
|
| 52 |
+
self.element_A = element_A
|
| 53 |
+
self.element_B = element_B
|
| 54 |
+
self.element_C = element_C
|
| 55 |
+
self.element_D = element_D
|
| 56 |
+
self.element_accumulator = element_accumulator
|
| 57 |
+
self.layout_A = layout_A
|
| 58 |
+
self.layout_B = layout_B
|
| 59 |
+
self.layout_C = layout_C
|
| 60 |
+
self.alignment_A = alignment_A
|
| 61 |
+
self.alignment_B = alignment_B
|
| 62 |
+
self.alignment_C = alignment_C
|
| 63 |
+
self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C,
|
| 64 |
+
element_D=element_D, element_accumulator=element_accumulator,
|
| 65 |
+
layout_A=layout_A, layout_B=layout_B, layout_C=layout_C)
|
| 66 |
+
self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
|
| 67 |
+
|
| 68 |
+
def _plans_equal(self, other_plan) -> bool:
|
| 69 |
+
"""
|
| 70 |
+
Compares whether two plans are equal
|
| 71 |
+
|
| 72 |
+
:param other_plan: plan to compare against the default GEMM
|
| 73 |
+
:type other_plan: cutlass_cppgen.op.Gemm
|
| 74 |
+
|
| 75 |
+
:return: whether `other_plan` is equivalent to `self.plan`
|
| 76 |
+
:rtype: bool
|
| 77 |
+
"""
|
| 78 |
+
other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C)
|
| 79 |
+
|
| 80 |
+
# Compare whether the operations are equal by comparing the C++ code that would be emitted for them
|
| 81 |
+
return self.op.rt_module.emit() == other_op.rt_module.emit()
|
| 82 |
+
|
| 83 |
+
def generic_test(self):
|
| 84 |
+
"""
|
| 85 |
+
Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types
|
| 86 |
+
and layouts for constructing the Gemm interface
|
| 87 |
+
"""
|
| 88 |
+
if not datatypes.is_numpy_available():
|
| 89 |
+
return
|
| 90 |
+
|
| 91 |
+
# Test when specifying all parameters
|
| 92 |
+
plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
|
| 93 |
+
element_D=self.element_D, element_accumulator=self.element_accumulator,
|
| 94 |
+
layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C)
|
| 95 |
+
assert self._plans_equal(plan_other)
|
| 96 |
+
|
| 97 |
+
# Test when specifying all parameters but A
|
| 98 |
+
plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C,
|
| 99 |
+
element_D=self.element_D, element_accumulator=self.element_accumulator,
|
| 100 |
+
layout_B=self.layout_B, layout_C=self.layout_C,
|
| 101 |
+
element=self.element_A, layout=self.layout_A)
|
| 102 |
+
assert self._plans_equal(plan_other)
|
| 103 |
+
|
| 104 |
+
# Test when specifying all parameters but A and B as tensors and using generic element and output
|
| 105 |
+
# Only run this test if the layouts and types for A and B are equal.
|
| 106 |
+
if self.element_A == self.element_B and self.layout_A == self.layout_B:
|
| 107 |
+
plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator,
|
| 108 |
+
layout_C=self.layout_C, element=self.element_A, layout=self.layout_A)
|
| 109 |
+
assert self._plans_equal(plan_other)
|
| 110 |
+
|
| 111 |
+
# Test without explicit accumulator. Only run if the type of C and the accumulator.
|
| 112 |
+
if self.element_C == self.element_accumulator:
|
| 113 |
+
plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
|
| 114 |
+
element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B,
|
| 115 |
+
layout_C=self.layout_C)
|
| 116 |
+
assert self._plans_equal(plan_other)
|
| 117 |
+
|
| 118 |
+
# Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
|
| 119 |
+
if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
|
| 120 |
+
and self.element_A == self.element_accumulator and
|
| 121 |
+
self.layout_A == self.layout_B and self.layout_A == self.layout_C):
|
| 122 |
+
plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A)
|
| 123 |
+
assert self._plans_equal(plan_other)
|
| 124 |
+
|
| 125 |
+
def numpy_test(self):
|
| 126 |
+
"""
|
| 127 |
+
Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend
|
| 128 |
+
"""
|
| 129 |
+
if not datatypes.is_numpy_available():
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
import numpy as np
|
| 133 |
+
type_A = datatypes.numpy_type(self.element_A)
|
| 134 |
+
type_B = datatypes.numpy_type(self.element_B)
|
| 135 |
+
type_C = datatypes.numpy_type(self.element_C)
|
| 136 |
+
type_D = datatypes.numpy_type(self.element_D)
|
| 137 |
+
type_accum = datatypes.numpy_type(self.element_accumulator)
|
| 138 |
+
|
| 139 |
+
layout_to_order = {
|
| 140 |
+
cutlass_cppgen.LayoutType.RowMajor: 'C',
|
| 141 |
+
cutlass_cppgen.LayoutType.ColumnMajor: 'F'
|
| 142 |
+
}
|
| 143 |
+
size = (2, 2)
|
| 144 |
+
A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A)
|
| 145 |
+
B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B)
|
| 146 |
+
C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C)
|
| 147 |
+
D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D)
|
| 148 |
+
|
| 149 |
+
# Test when specifying all parameters via tensors
|
| 150 |
+
plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum)
|
| 151 |
+
assert self._plans_equal(plan_np)
|
| 152 |
+
|
| 153 |
+
# Test when specifying all parameters but A as tensors
|
| 154 |
+
plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A)
|
| 155 |
+
assert self._plans_equal(plan_np)
|
| 156 |
+
|
| 157 |
+
# Test when specifying all parameters but A and B as tensors and using generic element and output
|
| 158 |
+
# Only run this test if the layouts and types for A and B are equal.
|
| 159 |
+
if type_A == type_B and self.layout_A == self.layout_B:
|
| 160 |
+
plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A)
|
| 161 |
+
assert self._plans_equal(plan_np)
|
| 162 |
+
|
| 163 |
+
# Test without explicit accumulator. Only run if the type of C and the accumulator.
|
| 164 |
+
if type_C == type_accum:
|
| 165 |
+
plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D)
|
| 166 |
+
assert self._plans_equal(plan_np)
|
| 167 |
+
|
| 168 |
+
# Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
|
| 169 |
+
if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and
|
| 170 |
+
self.layout_A == self.layout_B and self.layout_A == self.layout_C):
|
| 171 |
+
plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A)
|
| 172 |
+
assert self._plans_equal(plan_np)
|
| 173 |
+
|
| 174 |
+
def test_all(self):
|
| 175 |
+
"""
|
| 176 |
+
Runs all tests on the Gemm interface
|
| 177 |
+
"""
|
| 178 |
+
self.generic_test()
|
| 179 |
+
self.numpy_test()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class GemmEquivalenceTest(unittest.TestCase):
|
| 183 |
+
"""
|
| 184 |
+
Tests the equivalence of different constructions of the Gemm interface
|
| 185 |
+
"""
|
| 186 |
+
@unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
|
| 187 |
+
def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self):
|
| 188 |
+
gemm_eq = GemmEquivalence(
|
| 189 |
+
element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
|
| 190 |
+
element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16,
|
| 191 |
+
layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor,
|
| 192 |
+
alignment_A=8, alignment_B=8, alignment_C=8)
|
| 193 |
+
gemm_eq.test_all()
|
| 194 |
+
|
| 195 |
+
@unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
|
| 196 |
+
def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self):
|
| 197 |
+
gemm_eq = GemmEquivalence(
|
| 198 |
+
element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
|
| 199 |
+
element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32,
|
| 200 |
+
layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor,
|
| 201 |
+
alignment_A=8, alignment_B=8, alignment_C=8)
|
| 202 |
+
gemm_eq.test_all()
|
| 203 |
+
|
| 204 |
+
@unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
|
| 205 |
+
def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self):
|
| 206 |
+
gemm_eq = GemmEquivalence(
|
| 207 |
+
element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
|
| 208 |
+
element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16,
|
| 209 |
+
layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor,
|
| 210 |
+
alignment_A=8, alignment_B=8, alignment_C=8)
|
| 211 |
+
gemm_eq.test_all()
|
| 212 |
+
|
| 213 |
+
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.")
|
| 214 |
+
def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self):
|
| 215 |
+
gemm_eq = GemmEquivalence(
|
| 216 |
+
element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64,
|
| 217 |
+
element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64,
|
| 218 |
+
layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor,
|
| 219 |
+
alignment_A=1, alignment_B=1, alignment_C=1)
|
| 220 |
+
gemm_eq.test_all()
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class GemmErrorTests(unittest.TestCase):
|
| 224 |
+
"""
|
| 225 |
+
Tests various error scenarios that arise with the high-level Gemm interface
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
def test_alignment(self):
|
| 229 |
+
"""
|
| 230 |
+
Tests case in which the alignment specified is unsupported
|
| 231 |
+
"""
|
| 232 |
+
plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 233 |
+
|
| 234 |
+
with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'):
|
| 235 |
+
op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16)
|
| 236 |
+
|
| 237 |
+
def test_tensorop_availability(self):
|
| 238 |
+
"""
|
| 239 |
+
Tests case in which only SIMT operations are available but TensorOp is requested
|
| 240 |
+
"""
|
| 241 |
+
cc = device_cc()
|
| 242 |
+
|
| 243 |
+
# F64 Tensor Core operations are only avaiable on certain devices
|
| 244 |
+
supports_tensorop_f64 = cc in [80, 89, 90]
|
| 245 |
+
plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 246 |
+
|
| 247 |
+
error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}'
|
| 248 |
+
with ExpectException(not supports_tensorop_f64, error_msg):
|
| 249 |
+
plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp
|
| 250 |
+
|
| 251 |
+
expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt
|
| 252 |
+
assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}'
|
| 253 |
+
|
| 254 |
+
@unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.")
|
| 255 |
+
def test_opclass_switch(self):
|
| 256 |
+
"""
|
| 257 |
+
Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT)
|
| 258 |
+
"""
|
| 259 |
+
plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 260 |
+
assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp
|
| 261 |
+
|
| 262 |
+
# Ensure that all tile descriptions have opclass of TensorOp
|
| 263 |
+
for td in plan.tile_descriptions():
|
| 264 |
+
assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp
|
| 265 |
+
|
| 266 |
+
plan.opclass = cutlass_cppgen.OpcodeClass.Simt
|
| 267 |
+
|
| 268 |
+
# Ensure that all tile descriptions have opclass of Simt
|
| 269 |
+
for td in plan.tile_descriptions():
|
| 270 |
+
assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt
|
| 271 |
+
|
| 272 |
+
def test_invalid_tile_description(self):
|
| 273 |
+
"""
|
| 274 |
+
Tests scenarios in which an invalid tile description is provided for a given CC
|
| 275 |
+
"""
|
| 276 |
+
cc = device_cc()
|
| 277 |
+
plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
|
| 278 |
+
td = plan.tile_descriptions()[0]
|
| 279 |
+
stages = td.stages
|
| 280 |
+
|
| 281 |
+
# Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage
|
| 282 |
+
# count should be used
|
| 283 |
+
with ExpectException(cc < 90, f'Requested zero stages'):
|
| 284 |
+
td.stages = 0
|
| 285 |
+
plan.construct(td)
|
| 286 |
+
|
| 287 |
+
if cc < 90:
|
| 288 |
+
with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
|
| 289 |
+
td.stages = 3
|
| 290 |
+
plan.construct(td)
|
| 291 |
+
elif cc == 90:
|
| 292 |
+
original_kschedule = td.kernel_schedule
|
| 293 |
+
original_eschedule = td.epilogue_schedule
|
| 294 |
+
with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'):
|
| 295 |
+
td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong
|
| 296 |
+
td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized
|
| 297 |
+
td.stages = 3
|
| 298 |
+
plan.construct(td)
|
| 299 |
+
# Reset schedules
|
| 300 |
+
td.kernel_schedule = original_kschedule
|
| 301 |
+
td.epilogue_schedule = original_eschedule
|
| 302 |
+
elif cc in [100, 101, 103]:
|
| 303 |
+
with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'):
|
| 304 |
+
td.stages = 3
|
| 305 |
+
plan.construct(td)
|
| 306 |
+
|
| 307 |
+
with ExpectException(True, f'Requested too many stages'):
|
| 308 |
+
td.stages = 100
|
| 309 |
+
plan.construct(td)
|
| 310 |
+
|
| 311 |
+
# Reset stage count
|
| 312 |
+
td.stages = stages
|
| 313 |
+
|
| 314 |
+
cluster_shape = td.cluster_shape
|
| 315 |
+
with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'):
|
| 316 |
+
td.cluster_shape = [2, 1, 1]
|
| 317 |
+
plan.construct(td)
|
| 318 |
+
|
| 319 |
+
# Reset cluster shape
|
| 320 |
+
td.cluster_shape = cluster_shape
|
| 321 |
+
|
| 322 |
+
with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'):
|
| 323 |
+
td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong
|
| 324 |
+
td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
|
| 325 |
+
plan.construct(td)
|
| 326 |
+
|
| 327 |
+
with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'):
|
| 328 |
+
td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong
|
| 329 |
+
td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto
|
| 330 |
+
plan.construct(td)
|
| 331 |
+
|
| 332 |
+
with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'):
|
| 333 |
+
td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto
|
| 334 |
+
td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
|
| 335 |
+
plan.construct(td)
|
| 336 |
+
|
| 337 |
+
with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'):
|
| 338 |
+
td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative
|
| 339 |
+
td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative
|
| 340 |
+
td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK
|
| 341 |
+
plan.construct(td)
|
| 342 |
+
|
| 343 |
+
# Ensure that all returned tile descriptions are unique
|
| 344 |
+
ops = {}
|
| 345 |
+
for i, td in enumerate(plan.tile_descriptions()):
|
| 346 |
+
op = plan.construct(td)
|
| 347 |
+
code_str = op.rt_module.emit()
|
| 348 |
+
if code_str in ops:
|
| 349 |
+
conflicting_td = ops[code_str]
|
| 350 |
+
assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}'
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
if __name__ == '__main__':
|
| 354 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Helper functions & classes for interface test
|
| 35 |
+
"""
|
| 36 |
+
class ExpectException:
|
| 37 |
+
"""
|
| 38 |
+
Utility class to assert that an exception was raised when expected
|
| 39 |
+
|
| 40 |
+
Example:
|
| 41 |
+
|
| 42 |
+
.. highlight:: python
|
| 43 |
+
.. code-block:: python
|
| 44 |
+
|
| 45 |
+
with ExceptionExpected(True, 'Division by zero'):
|
| 46 |
+
x = 1.0 / 0.0
|
| 47 |
+
|
| 48 |
+
:param exception_expected: whether an exception is expected to be raised
|
| 49 |
+
:type exception_expected: bool
|
| 50 |
+
:param message: message to print if an exception is raised when not expected or vice versa
|
| 51 |
+
:type message: str
|
| 52 |
+
"""
|
| 53 |
+
def __init__(self, exception_expected: bool, message: str = '', verify_msg=False):
|
| 54 |
+
self.exception_expected = exception_expected
|
| 55 |
+
self.message = message
|
| 56 |
+
self.verify_msg = verify_msg
|
| 57 |
+
|
| 58 |
+
def __enter__(self):
|
| 59 |
+
return self
|
| 60 |
+
|
| 61 |
+
def __exit__(self, exc_type, exc_val, traceback):
|
| 62 |
+
exception_raised = exc_type is not None
|
| 63 |
+
assert self.exception_expected == exception_raised, self.message
|
| 64 |
+
if self.verify_msg:
|
| 65 |
+
exc_message = f"{exc_type.__name__}: {exc_val}"
|
| 66 |
+
assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}"
|
| 67 |
+
|
| 68 |
+
# Suppress the exception
|
| 69 |
+
return True
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Utility script for discovering and running all PyCuTe tests
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import argparse
|
| 38 |
+
import logging
|
| 39 |
+
import pathlib
|
| 40 |
+
import unittest
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def numeric_log_level(log_level: str) -> int:
|
| 44 |
+
"""
|
| 45 |
+
Converts the string identifier of the log level into the numeric identifier used
|
| 46 |
+
in setting the log level
|
| 47 |
+
|
| 48 |
+
:param x: string representation of log level (e.g., 'INFO', 'DEBUG')
|
| 49 |
+
:type x: str
|
| 50 |
+
|
| 51 |
+
:return: numeric representation of log level
|
| 52 |
+
:rtype: int
|
| 53 |
+
"""
|
| 54 |
+
numeric_level = getattr(logging, log_level.upper(), None)
|
| 55 |
+
if not isinstance(numeric_level, int):
|
| 56 |
+
raise ValueError(f"Invalid log level: {log_level}")
|
| 57 |
+
return numeric_level
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
parser = argparse.ArgumentParser()
|
| 62 |
+
parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False,
|
| 63 |
+
help='Logging level to be used by the generator script')
|
| 64 |
+
args = parser.parse_args()
|
| 65 |
+
|
| 66 |
+
# Set the logging level based on the user-provided `--log-level` command-line option
|
| 67 |
+
logging.basicConfig(level=args.log_level)
|
| 68 |
+
|
| 69 |
+
loader = unittest.TestLoader()
|
| 70 |
+
script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
|
| 71 |
+
tests = loader.discover(script_dir, "test_*.py")
|
| 72 |
+
test_runner = unittest.runner.TextTestRunner()
|
| 73 |
+
results = test_runner.run(tests)
|
| 74 |
+
if not results.wasSuccessful():
|
| 75 |
+
raise Exception("Test cases failed")
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.coalesce
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
from pycute import *
|
| 41 |
+
|
| 42 |
+
_LOGGER = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestCoalesce(unittest.TestCase):
|
| 46 |
+
def helper_test_coalesce(self, layout):
|
| 47 |
+
layoutR = coalesce(layout)
|
| 48 |
+
|
| 49 |
+
_LOGGER.debug(f"{layout} => {layoutR}")
|
| 50 |
+
|
| 51 |
+
self.assertEqual(size(layoutR), size(layout))
|
| 52 |
+
|
| 53 |
+
for i in range(size(layout)):
|
| 54 |
+
self.assertEqual(layoutR(i), layout(i))
|
| 55 |
+
|
| 56 |
+
def test_coalesce(self):
|
| 57 |
+
layout = Layout(1,0)
|
| 58 |
+
self.helper_test_coalesce(layout)
|
| 59 |
+
|
| 60 |
+
layout = Layout(1,1)
|
| 61 |
+
self.helper_test_coalesce(layout)
|
| 62 |
+
|
| 63 |
+
layout = Layout((2,4))
|
| 64 |
+
self.helper_test_coalesce(layout)
|
| 65 |
+
|
| 66 |
+
layout = Layout((2,4,6))
|
| 67 |
+
self.helper_test_coalesce(layout)
|
| 68 |
+
|
| 69 |
+
layout = Layout((2,4,6), (1,6,2))
|
| 70 |
+
self.helper_test_coalesce(layout)
|
| 71 |
+
|
| 72 |
+
layout = Layout((2,1,6), (1,7,2))
|
| 73 |
+
self.helper_test_coalesce(layout)
|
| 74 |
+
|
| 75 |
+
layout = Layout((2,1,6), (4,7,8))
|
| 76 |
+
self.helper_test_coalesce(layout)
|
| 77 |
+
|
| 78 |
+
layout = Layout((2,(4,6)))
|
| 79 |
+
self.helper_test_coalesce(layout)
|
| 80 |
+
|
| 81 |
+
layout = Layout((2,4), (4,1))
|
| 82 |
+
self.helper_test_coalesce(layout)
|
| 83 |
+
|
| 84 |
+
layout = Layout((2,4,6), (24,6,1))
|
| 85 |
+
self.helper_test_coalesce(layout)
|
| 86 |
+
|
| 87 |
+
layout = Layout((2,1,3), (2,4,4))
|
| 88 |
+
self.helper_test_coalesce(layout)
|
| 89 |
+
|
| 90 |
+
layout = Layout(((2,2),(2,2)), ((1,4),(8,32)))
|
| 91 |
+
self.helper_test_coalesce(layout)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
if __name__ == "__main__":
|
| 95 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.complement
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
from pycute import *
|
| 41 |
+
|
| 42 |
+
_LOGGER = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestComplement(unittest.TestCase):
|
| 46 |
+
def helper_test_complement(self, layout):
|
| 47 |
+
layoutR = complement(layout)
|
| 48 |
+
|
| 49 |
+
_LOGGER.debug(f"{layout} => {layoutR}")
|
| 50 |
+
|
| 51 |
+
# Post-condition: test disjointness of the codomains
|
| 52 |
+
for a in range(size(layout)):
|
| 53 |
+
for b in range(size(layoutR)):
|
| 54 |
+
assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0)
|
| 55 |
+
|
| 56 |
+
def test_complement(self):
|
| 57 |
+
test = Layout(1,0)
|
| 58 |
+
self.helper_test_complement(test)
|
| 59 |
+
|
| 60 |
+
test = Layout(1,1)
|
| 61 |
+
self.helper_test_complement(test)
|
| 62 |
+
|
| 63 |
+
test = Layout(4,0)
|
| 64 |
+
self.helper_test_complement(test)
|
| 65 |
+
|
| 66 |
+
test = Layout((2,4),(1,2))
|
| 67 |
+
self.helper_test_complement(test)
|
| 68 |
+
|
| 69 |
+
test = Layout((2,3),(1,2))
|
| 70 |
+
self.helper_test_complement(test)
|
| 71 |
+
|
| 72 |
+
test = Layout((2,4),(1,4))
|
| 73 |
+
self.helper_test_complement(test)
|
| 74 |
+
|
| 75 |
+
test = Layout((2,4,8),(8,1,64))
|
| 76 |
+
self.helper_test_complement(test)
|
| 77 |
+
|
| 78 |
+
test = Layout(((2,2),(2,2)),((1,4),(8,32)))
|
| 79 |
+
self.helper_test_complement(test)
|
| 80 |
+
|
| 81 |
+
test = Layout((2,(3,4)),(3,(1,6)))
|
| 82 |
+
self.helper_test_complement(test)
|
| 83 |
+
|
| 84 |
+
test = Layout((4,6),(1,6))
|
| 85 |
+
self.helper_test_complement(test)
|
| 86 |
+
|
| 87 |
+
test = Layout((4,10),(1,10))
|
| 88 |
+
self.helper_test_complement(test)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py
ADDED
|
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.composition
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
from pycute import *
|
| 41 |
+
|
| 42 |
+
_LOGGER = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestComposition(unittest.TestCase):
|
| 46 |
+
def helper_test_composition(self, layoutA, layoutB):
|
| 47 |
+
layoutR = composition(layoutA, layoutB)
|
| 48 |
+
|
| 49 |
+
_LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}")
|
| 50 |
+
|
| 51 |
+
# True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR.
|
| 52 |
+
|
| 53 |
+
# Test that R(c) = A(B(c)) for all coordinates c in layoutR
|
| 54 |
+
for i in range(size(layoutR)):
|
| 55 |
+
self.assertEqual(layoutR(i), layoutA(layoutB(i)))
|
| 56 |
+
|
| 57 |
+
def test_composition(self):
|
| 58 |
+
layoutA = Layout(1,0)
|
| 59 |
+
layoutB = Layout(1,0)
|
| 60 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 61 |
+
|
| 62 |
+
layoutA = Layout(1,0)
|
| 63 |
+
layoutB = Layout(1,1)
|
| 64 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 65 |
+
|
| 66 |
+
layoutA = Layout(1,1)
|
| 67 |
+
layoutB = Layout(1,0)
|
| 68 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 69 |
+
|
| 70 |
+
layoutA = Layout(1,1)
|
| 71 |
+
layoutB = Layout(1,1)
|
| 72 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 73 |
+
|
| 74 |
+
layoutA = Layout((4))
|
| 75 |
+
layoutB = Layout((4))
|
| 76 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 77 |
+
|
| 78 |
+
layoutA = Layout((4), (2))
|
| 79 |
+
layoutB = Layout((4))
|
| 80 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 81 |
+
|
| 82 |
+
layoutA = Layout((4))
|
| 83 |
+
layoutB = Layout((4), (2))
|
| 84 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 85 |
+
|
| 86 |
+
layoutA = Layout((4), (0))
|
| 87 |
+
layoutB = Layout((4))
|
| 88 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 89 |
+
|
| 90 |
+
layoutA = Layout((4))
|
| 91 |
+
layoutB = Layout((4), (0))
|
| 92 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 93 |
+
|
| 94 |
+
layoutA = Layout((1), (0))
|
| 95 |
+
layoutB = Layout((4))
|
| 96 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 97 |
+
|
| 98 |
+
layoutA = Layout((4))
|
| 99 |
+
layoutB = Layout((1), (0))
|
| 100 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 101 |
+
|
| 102 |
+
layoutA = Layout((4))
|
| 103 |
+
layoutB = Layout((2))
|
| 104 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 105 |
+
|
| 106 |
+
layoutA = Layout((4), (2))
|
| 107 |
+
layoutB = Layout((2))
|
| 108 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 109 |
+
|
| 110 |
+
layoutA = Layout((4))
|
| 111 |
+
layoutB = Layout((2), (2))
|
| 112 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 113 |
+
|
| 114 |
+
layoutA = Layout((4), (2))
|
| 115 |
+
layoutB = Layout((2), (2))
|
| 116 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 117 |
+
|
| 118 |
+
layoutA = Layout((12))
|
| 119 |
+
layoutB = Layout((4,3))
|
| 120 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 121 |
+
|
| 122 |
+
layoutA = Layout((12), (2))
|
| 123 |
+
layoutB = Layout((4,3))
|
| 124 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 125 |
+
|
| 126 |
+
layoutA = Layout((12))
|
| 127 |
+
layoutB = Layout((4,3), (3,1))
|
| 128 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 129 |
+
|
| 130 |
+
layoutA = Layout((12), (2))
|
| 131 |
+
layoutB = Layout((4,3), (3,1))
|
| 132 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 133 |
+
|
| 134 |
+
layoutA = Layout((12))
|
| 135 |
+
layoutB = Layout((2,3), (2,4))
|
| 136 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 137 |
+
|
| 138 |
+
layoutA = Layout((4,3))
|
| 139 |
+
layoutB = Layout((4,3))
|
| 140 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 141 |
+
|
| 142 |
+
layoutA = Layout((4,3))
|
| 143 |
+
layoutB = Layout((12))
|
| 144 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 145 |
+
|
| 146 |
+
layoutA = Layout((4,3))
|
| 147 |
+
layoutB = Layout((6), (2))
|
| 148 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 149 |
+
|
| 150 |
+
layoutA = Layout((4,3))
|
| 151 |
+
layoutB = Layout((6,2), (2,1))
|
| 152 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 153 |
+
|
| 154 |
+
layoutA = Layout((4,3), (3,1))
|
| 155 |
+
layoutB = Layout((4,3))
|
| 156 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 157 |
+
|
| 158 |
+
layoutA = Layout((4,3), (3,1))
|
| 159 |
+
layoutB = Layout((12))
|
| 160 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 161 |
+
|
| 162 |
+
layoutA = Layout((4,3), (3,1))
|
| 163 |
+
layoutB = Layout((6), (2))
|
| 164 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 165 |
+
|
| 166 |
+
layoutA = Layout((4,3), (3,1))
|
| 167 |
+
layoutB = Layout((6,2), (2,1))
|
| 168 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 169 |
+
|
| 170 |
+
layoutA = Layout((8,8))
|
| 171 |
+
layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32)))
|
| 172 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 173 |
+
|
| 174 |
+
layoutA = Layout((8,8), (8,1))
|
| 175 |
+
layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32)))
|
| 176 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 177 |
+
|
| 178 |
+
layoutA = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32)))
|
| 179 |
+
layoutB = Layout(8, 4)
|
| 180 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 181 |
+
|
| 182 |
+
layoutA = Layout(((4,2)), ((1,16)))
|
| 183 |
+
layoutB = Layout((4,2), (2,1))
|
| 184 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 185 |
+
|
| 186 |
+
layoutA = Layout((2,2), (2,1))
|
| 187 |
+
layoutB = Layout((2,2), (2,1))
|
| 188 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 189 |
+
|
| 190 |
+
layoutA = Layout((4,8,2))
|
| 191 |
+
layoutB = Layout((2,2,2), (2,8,1))
|
| 192 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 193 |
+
|
| 194 |
+
layoutA = Layout((4,8,2), (2,8,1))
|
| 195 |
+
layoutB = Layout((2,2,2), (1,8,2))
|
| 196 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 197 |
+
|
| 198 |
+
layoutA = Layout((4,8,2), (2,8,1))
|
| 199 |
+
layoutB = Layout((4,2,2), (2,8,1))
|
| 200 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 201 |
+
|
| 202 |
+
# Pre-coalesced LHS
|
| 203 |
+
layoutA = Layout((4,6,8),(1,4,7))
|
| 204 |
+
layoutB = Layout((6),(1))
|
| 205 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 206 |
+
|
| 207 |
+
# Mid-layout truncation
|
| 208 |
+
layoutA = Layout((4,6,8,10),(2,3,5,7))
|
| 209 |
+
layoutB = Layout(6,12)
|
| 210 |
+
self.helper_test_composition(layoutA, layoutB)
|
| 211 |
+
|
| 212 |
+
if __name__ == "__main__":
|
| 213 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.int_tuple
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import unittest
|
| 38 |
+
|
| 39 |
+
from pycute import *
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class TestIntTuple(unittest.TestCase):
|
| 43 |
+
def test_product(self):
|
| 44 |
+
self.assertEqual(product(2), 2)
|
| 45 |
+
|
| 46 |
+
self.assertEqual(product((3,2)), 6)
|
| 47 |
+
|
| 48 |
+
self.assertEqual(product(product(((2,3),4))), 24)
|
| 49 |
+
|
| 50 |
+
def test_inner_product(self):
|
| 51 |
+
self.assertEqual(inner_product(2, 3), 6)
|
| 52 |
+
|
| 53 |
+
self.assertEqual(inner_product((1,2), (3,2)), 7)
|
| 54 |
+
|
| 55 |
+
self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15)
|
| 56 |
+
|
| 57 |
+
def test_shape_div(self):
|
| 58 |
+
self.assertEqual(shape_div((3,4), 6), (1,2))
|
| 59 |
+
|
| 60 |
+
self.assertEqual(shape_div((3,4), 12), (1,1))
|
| 61 |
+
|
| 62 |
+
self.assertEqual(shape_div((3,4), 36), (1,1))
|
| 63 |
+
|
| 64 |
+
self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2))
|
| 65 |
+
|
| 66 |
+
self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2)))
|
| 67 |
+
|
| 68 |
+
def test_prefix_product(self):
|
| 69 |
+
self.assertEqual(prefix_product(2), 1)
|
| 70 |
+
|
| 71 |
+
self.assertEqual(prefix_product((3,2)), (1,3))
|
| 72 |
+
|
| 73 |
+
self.assertEqual(prefix_product((3,2,4)), (1,3,6))
|
| 74 |
+
|
| 75 |
+
self.assertEqual(prefix_product(((2,3),4)), ((1,2),6))
|
| 76 |
+
|
| 77 |
+
self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))),
|
| 78 |
+
((1,2),(6,12,12),(24,120,240)))
|
| 79 |
+
|
| 80 |
+
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.left_inverse
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
from pycute import *
|
| 41 |
+
|
| 42 |
+
_LOGGER = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestLeftInverse(unittest.TestCase):
|
| 46 |
+
def helper_test_left_inverse(self, layout):
|
| 47 |
+
inv_layout = left_inverse(layout)
|
| 48 |
+
|
| 49 |
+
_LOGGER.debug(f"{layout} => {inv_layout}")
|
| 50 |
+
|
| 51 |
+
for i in range(size(layout)):
|
| 52 |
+
self.assertEqual(inv_layout(layout(i)), i)
|
| 53 |
+
|
| 54 |
+
def test_left_inverse(self):
|
| 55 |
+
test = Layout(1,0)
|
| 56 |
+
self.helper_test_left_inverse(test)
|
| 57 |
+
|
| 58 |
+
test = Layout((1,1),(0,0))
|
| 59 |
+
self.helper_test_left_inverse(test)
|
| 60 |
+
|
| 61 |
+
test = Layout(1,1)
|
| 62 |
+
self.helper_test_left_inverse(test)
|
| 63 |
+
|
| 64 |
+
test = Layout(4,1)
|
| 65 |
+
self.helper_test_left_inverse(test)
|
| 66 |
+
|
| 67 |
+
test = Layout(4,2)
|
| 68 |
+
self.helper_test_left_inverse(test)
|
| 69 |
+
|
| 70 |
+
test = Layout((8,4),(1,8))
|
| 71 |
+
self.helper_test_left_inverse(test)
|
| 72 |
+
|
| 73 |
+
test = Layout((8,4),(4,1))
|
| 74 |
+
self.helper_test_left_inverse(test)
|
| 75 |
+
|
| 76 |
+
test = Layout((2,4,6),(1,2,8))
|
| 77 |
+
self.helper_test_left_inverse(test)
|
| 78 |
+
|
| 79 |
+
test = Layout((2,4,6),(4,1,8))
|
| 80 |
+
self.helper_test_left_inverse(test)
|
| 81 |
+
|
| 82 |
+
test = Layout((4,2),(1,16))
|
| 83 |
+
self.helper_test_left_inverse(test)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.left_inverse
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
|
| 40 |
+
from pycute import *
|
| 41 |
+
|
| 42 |
+
_LOGGER = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TestRightInverse(unittest.TestCase):
|
| 46 |
+
def helper_test_right_inverse(self, layout):
|
| 47 |
+
inv_layout = right_inverse(layout)
|
| 48 |
+
|
| 49 |
+
_LOGGER.debug(f"{layout} => {inv_layout}")
|
| 50 |
+
|
| 51 |
+
for i in range(size(inv_layout)):
|
| 52 |
+
self.assertEqual(layout(inv_layout(i)), i)
|
| 53 |
+
|
| 54 |
+
def test_right_inverse(self):
|
| 55 |
+
test = Layout(1,0)
|
| 56 |
+
self.helper_test_right_inverse(test)
|
| 57 |
+
|
| 58 |
+
test = Layout((1,1),(0,0))
|
| 59 |
+
self.helper_test_right_inverse(test)
|
| 60 |
+
|
| 61 |
+
test = Layout((3,7),(0,0))
|
| 62 |
+
self.helper_test_right_inverse(test)
|
| 63 |
+
|
| 64 |
+
test = Layout(1,1)
|
| 65 |
+
self.helper_test_right_inverse(test)
|
| 66 |
+
|
| 67 |
+
test = Layout(4,0)
|
| 68 |
+
self.helper_test_right_inverse(test)
|
| 69 |
+
|
| 70 |
+
test = Layout(4,1)
|
| 71 |
+
self.helper_test_right_inverse(test)
|
| 72 |
+
|
| 73 |
+
test = Layout(4,2)
|
| 74 |
+
self.helper_test_right_inverse(test)
|
| 75 |
+
|
| 76 |
+
test = Layout((2,4),(0,2))
|
| 77 |
+
self.helper_test_right_inverse(test)
|
| 78 |
+
|
| 79 |
+
test = Layout((8,4),(1,8))
|
| 80 |
+
self.helper_test_right_inverse(test)
|
| 81 |
+
|
| 82 |
+
test = Layout((8,4),(4,1))
|
| 83 |
+
self.helper_test_right_inverse(test)
|
| 84 |
+
|
| 85 |
+
test = Layout((2,4,6),(1,2,8))
|
| 86 |
+
self.helper_test_right_inverse(test)
|
| 87 |
+
|
| 88 |
+
test = Layout((2,4,6),(4,1,8))
|
| 89 |
+
self.helper_test_right_inverse(test)
|
| 90 |
+
|
| 91 |
+
test = Layout((4,2),(1,16))
|
| 92 |
+
self.helper_test_right_inverse(test)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#################################################################################################
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 4 |
+
# SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
+
#
|
| 6 |
+
# Redistribution and use in source and binary forms, with or without
|
| 7 |
+
# modification, are permitted provided that the following conditions are met:
|
| 8 |
+
#
|
| 9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
| 10 |
+
# list of conditions and the following disclaimer.
|
| 11 |
+
#
|
| 12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 13 |
+
# this list of conditions and the following disclaimer in the documentation
|
| 14 |
+
# and/or other materials provided with the distribution.
|
| 15 |
+
#
|
| 16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
| 17 |
+
# contributors may be used to endorse or promote products derived from
|
| 18 |
+
# this software without specific prior written permission.
|
| 19 |
+
#
|
| 20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 30 |
+
#
|
| 31 |
+
#################################################################################################
|
| 32 |
+
|
| 33 |
+
"""
|
| 34 |
+
Unit tests for pycute.typing
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
import logging
|
| 38 |
+
import unittest
|
| 39 |
+
from pycute import *
|
| 40 |
+
|
| 41 |
+
_LOGGER = logging.getLogger(__name__)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TestTyping(unittest.TestCase):
|
| 45 |
+
def helper_test_typing(self, _cls, _obj, cls, expected: bool):
|
| 46 |
+
_LOGGER.debug(f"issubclass({_cls}, {cls})")
|
| 47 |
+
_LOGGER.debug(f"isinstance({_obj}, {cls})")
|
| 48 |
+
|
| 49 |
+
self.assertEqual(expected, issubclass(_cls, cls))
|
| 50 |
+
self.assertEqual(expected, isinstance(_obj, cls))
|
| 51 |
+
|
| 52 |
+
def test_typing(self):
|
| 53 |
+
self.helper_test_typing(int, 1, Integer, True)
|
| 54 |
+
self.helper_test_typing(float, 1., Integer, False)
|
| 55 |
+
self.helper_test_typing(str, 'hi', Integer, False)
|
| 56 |
+
self.helper_test_typing(bool, False, Integer, False)
|
| 57 |
+
|
| 58 |
+
if __name__ == '__main__':
|
| 59 |
+
unittest.main()
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */
|
| 34 |
+
|
| 35 |
+
#pragma nv_diag_suppress boolean_controlling_expr_is_constant
|
| 36 |
+
#include <gtest/gtest.h>
|
| 37 |
+
#pragma nv_diag_warning boolean_controlling_expr_is_constant
|
| 38 |
+
#pragma warning( disable : 4503)
|
| 39 |
+
|
| 40 |
+
#include <cstdlib>
|
| 41 |
+
#include <string>
|
| 42 |
+
|
| 43 |
+
#include <cuda_runtime_api.h>
|
| 44 |
+
|
| 45 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 46 |
+
|
| 47 |
+
/// Gets a CUDA device
|
| 48 |
+
cudaDeviceProp GetCudaDevice();
|
| 49 |
+
|
| 50 |
+
/// Prints device properties
|
| 51 |
+
std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device);
|
| 52 |
+
|
| 53 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 54 |
+
|
| 55 |
+
/// Sets flags for Unit test
|
| 56 |
+
void FilterArchitecture();
|
| 57 |
+
|
| 58 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
|
| 60 |
+
/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order
|
| 61 |
+
// of problem sizes run by CUTLASS unit tests
|
| 62 |
+
int CutlassUnitTestProblemCount();
|
| 63 |
+
|
| 64 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 65 |
+
|
| 66 |
+
// active test macro
|
| 67 |
+
#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
|
| 68 |
+
TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__
|
| 69 |
+
|
| 70 |
+
// disabled test macro
|
| 71 |
+
#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
|
| 72 |
+
TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {}
|
| 73 |
+
|
| 74 |
+
#if CUTLASS_TEST_LEVEL == 0
|
| 75 |
+
#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 76 |
+
#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 77 |
+
#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 78 |
+
#elif CUTLASS_TEST_LEVEL == 1
|
| 79 |
+
#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 80 |
+
#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 81 |
+
#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 82 |
+
#else
|
| 83 |
+
#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 84 |
+
#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 85 |
+
#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
|
| 86 |
+
#endif
|
| 87 |
+
|
| 88 |
+
#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS)
|
| 89 |
+
#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false
|
| 90 |
+
#endif
|
| 91 |
+
|
| 92 |
+
#if (__CUDACC_VER_MAJOR__ >= 12)
|
| 93 |
+
#define CUDA_12_0_SM90_FEATURES_SUPPORTED true
|
| 94 |
+
#else
|
| 95 |
+
#define CUDA_12_0_SM90_FEATURES_SUPPORTED false
|
| 96 |
+
#endif
|
| 97 |
+
|
| 98 |
+
#include <cutlass/cutlass.h>
|
| 99 |
+
#include <cutlass/numeric_types.h>
|
| 100 |
+
#include <cutlass/trace.h>
|
| 101 |
+
|
| 102 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h
ADDED
|
@@ -0,0 +1,907 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Helper to construct cached name for
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <typeinfo>
|
| 37 |
+
#include <fstream>
|
| 38 |
+
#include <list>
|
| 39 |
+
#include <utility>
|
| 40 |
+
#include <sstream>
|
| 41 |
+
|
| 42 |
+
#include "cutlass/cutlass.h"
|
| 43 |
+
#include "cutlass/layout/matrix.h"
|
| 44 |
+
#include "cutlass/conv/convolution.h"
|
| 45 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/conv/conv3d_problem_size.h"
|
| 48 |
+
#include "cutlass/core_io.h"
|
| 49 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 50 |
+
|
| 51 |
+
#include "thrust/universal_vector.h"
|
| 52 |
+
|
| 53 |
+
#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS
|
| 54 |
+
#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false
|
| 55 |
+
#endif
|
| 56 |
+
|
| 57 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 58 |
+
|
| 59 |
+
namespace test::conv::device {
|
| 60 |
+
|
| 61 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
|
| 63 |
+
/// Result of a test
|
| 64 |
+
/// Key identifying a cached test result: the operation name, an encoded
/// problem description, encoded operand types, and CRC hashes of the inputs.
struct CachedTestKey {

  std::string op;      ///< Concatenated string representation of operation performed
  std::string problem; ///< Concatenated string representation of problem description
  std::string types;   ///< Concatenated string representation of operand types
  uint32_t A;          ///< Hashed result of tensor A
  uint32_t B;          ///< Hashed result of tensor B
  uint32_t C;          ///< Hashed result of tensor C

  //
  // Methods
  //

  /// Default key: empty strings, zero hashes.
  inline CachedTestKey(): A(), B(), C() { }

  /// Constructs a fully populated key.
  inline CachedTestKey(
    std::string op_,       ///< Concatenated string representation of operation performed
    std::string problem_,  ///< Concatenated string representation of problem description
    std::string types_,    ///< Concatenated string representation of operand types
    uint32_t A_,           ///< Hashed result of tensor A
    uint32_t B_,           ///< Hashed result of tensor B
    uint32_t C_            ///< Hashed result of tensor C
  ):
    op(op_), problem(problem_), types(types_), A(A_), B(B_), C(C_)
  { }

  /// Two keys are equal only when every component matches.
  bool operator==(CachedTestKey const &rhs) const {
    return op == rhs.op &&
           problem == rhs.problem &&
           types == rhs.types &&
           A == rhs.A &&
           B == rhs.B &&
           C == rhs.C;
  }
};
|
| 94 |
+
|
| 95 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 96 |
+
|
| 97 |
+
inline std::istream &operator>>(std::istream &in, CachedTestKey &result) {
|
| 98 |
+
|
| 99 |
+
in >> result.op;
|
| 100 |
+
in >> result.problem;
|
| 101 |
+
in >> result.types;
|
| 102 |
+
in >> result.A;
|
| 103 |
+
in >> result.B;
|
| 104 |
+
in >> result.C;
|
| 105 |
+
|
| 106 |
+
return in;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) {
|
| 110 |
+
|
| 111 |
+
out << result.op << " ";
|
| 112 |
+
out << result.problem << " ";
|
| 113 |
+
out << result.types << " ";
|
| 114 |
+
out << result.A << " ";
|
| 115 |
+
out << result.B << " ";
|
| 116 |
+
out << result.C << " ";
|
| 117 |
+
|
| 118 |
+
return out;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 122 |
+
|
| 123 |
+
/// Cached result of a test: the hash of output tensor D.
/// A zero hash is interpreted as "no result present".
struct CachedTestResult {
  uint32_t D;  // CRC hash of the output tensor

  //
  // Methods
  //

  CachedTestResult(): D() { }

  CachedTestResult(uint32_t D): D(D) { }

  /// True when a (non-zero) result hash has been recorded.
  operator bool() const {
    return D != 0u;
  }
};
|
| 139 |
+
|
| 140 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 141 |
+
|
| 142 |
+
inline std::istream &operator>>(std::istream &in, CachedTestResult &result) {
|
| 143 |
+
in >> result.D;
|
| 144 |
+
return in;
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) {
|
| 148 |
+
out << result.D;
|
| 149 |
+
return out;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 153 |
+
|
| 154 |
+
struct CachedTestResultListing {
|
| 155 |
+
|
| 156 |
+
std::list<std::pair<CachedTestKey, CachedTestResult>> results;
|
| 157 |
+
|
| 158 |
+
//
|
| 159 |
+
// Methods
|
| 160 |
+
//
|
| 161 |
+
|
| 162 |
+
inline CachedTestResultListing(std::string const &path) {
|
| 163 |
+
std::ifstream file(path);
|
| 164 |
+
|
| 165 |
+
while (file.good()) {
|
| 166 |
+
CachedTestKey key;
|
| 167 |
+
file >> key;
|
| 168 |
+
|
| 169 |
+
CachedTestResult result;
|
| 170 |
+
file >> result;
|
| 171 |
+
|
| 172 |
+
if (result) {
|
| 173 |
+
results.push_back(std::make_pair(key, result));
|
| 174 |
+
}
|
| 175 |
+
}
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
/// Returns the cached result
|
| 179 |
+
std::pair<bool, CachedTestResult> find(CachedTestKey const &rhs) const {
|
| 180 |
+
for (auto const & result : results) {
|
| 181 |
+
if (result.first == rhs) {
|
| 182 |
+
return std::make_pair(true, result.second);
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
return std::make_pair(false, CachedTestResult());
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
/// Appends an entry
|
| 189 |
+
void append(CachedTestKey const &key, CachedTestResult const &result) {
|
| 190 |
+
if (result) {
|
| 191 |
+
results.push_back(std::make_pair(key, result));
|
| 192 |
+
}
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
/// Writes the entire listing to a file
|
| 196 |
+
bool write(std::string const &path) {
|
| 197 |
+
std::ofstream file(path);
|
| 198 |
+
if (!file.good()) {
|
| 199 |
+
return false;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
for (auto const &result : results) {
|
| 203 |
+
file << result.first << result.second << std::endl;
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
return true;
|
| 207 |
+
}
|
| 208 |
+
};
|
| 209 |
+
|
| 210 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 211 |
+
|
| 212 |
+
/// Encodes a scalar as a string token suitable for use in a cache key;
/// negative values are rendered as 'n' followed by the magnitude so the
/// token contains no '-' characters.
template <typename Element>
struct ScalarEncoder {
  Element scalar;

  ScalarEncoder(Element s): scalar(s) { }

  /// Returns the encoded string form of the scalar.
  std::string str() const {
    std::stringstream ss;
    Element value = scalar;
    if (value < Element()) {
      // Emit the magnitude prefixed with 'n' instead of a minus sign
      ss << "n";
      value = -value;
    }
    ss << value;
    return ss.str();
  }
};
|
| 229 |
+
|
| 230 |
+
template <typename Element>
|
| 231 |
+
ScalarEncoder<Element> EncodeScalar(Element a) {
|
| 232 |
+
return ScalarEncoder<Element>(a);
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
template <typename Element>
|
| 236 |
+
struct ScalarEncoder<cutlass::complex<Element>> {
|
| 237 |
+
cutlass::complex<Element> scalar;
|
| 238 |
+
|
| 239 |
+
ScalarEncoder(cutlass::complex<Element> s): scalar(s) { }
|
| 240 |
+
|
| 241 |
+
std::string str() const {
|
| 242 |
+
std::stringstream ss;
|
| 243 |
+
ss << EncodeScalar<Element>(scalar.real()) << "_" << EncodeScalar<Element>(scalar.imag()) << "i";
|
| 244 |
+
return ss.str();
|
| 245 |
+
}
|
| 246 |
+
};
|
| 247 |
+
|
| 248 |
+
template <typename Element>
|
| 249 |
+
std::ostream &operator<<(std::ostream &out, ScalarEncoder<Element> const &scalar) {
|
| 250 |
+
out << scalar.str();
|
| 251 |
+
return out;
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 255 |
+
|
| 256 |
+
/// Maps a convolution operator enum to its short string token for cache keys.
inline char const *EncodeOperator(cutlass::conv::Operator conv_op) {
  switch (conv_op) {
    case cutlass::conv::Operator::kFprop:  return "fprop";
    case cutlass::conv::Operator::kDgrad:  return "dgrad";
    case cutlass::conv::Operator::kWgrad:  return "wgrad";
    case cutlass::conv::Operator::kDeconv: return "deconv";
  }
  // Unreachable for valid enumerators; keeps the key well-formed otherwise.
  return "conv_unknown";
}
|
| 265 |
+
|
| 266 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 267 |
+
|
| 268 |
+
// Encode GemmCoord (Gemm problem size)
|
| 269 |
+
inline std::ostream &EncodeProblemSize(
|
| 270 |
+
std::ostream &out,
|
| 271 |
+
cutlass::gemm::GemmCoord const &problem) {
|
| 272 |
+
|
| 273 |
+
out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_";
|
| 274 |
+
|
| 275 |
+
return out;
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 279 |
+
// Encode Conv2dProblemSize
|
| 280 |
+
inline std::ostream &EncodeProblemSize(
|
| 281 |
+
std::ostream &out,
|
| 282 |
+
cutlass::conv::Conv2dProblemSize const &problem) {
|
| 283 |
+
|
| 284 |
+
out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_"
|
| 285 |
+
<< problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_";
|
| 286 |
+
|
| 287 |
+
out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_";
|
| 288 |
+
out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_";
|
| 289 |
+
out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_";
|
| 290 |
+
|
| 291 |
+
switch (problem.mode) {
|
| 292 |
+
case cutlass::conv::Mode::kCrossCorrelation:
|
| 293 |
+
out << "corr";
|
| 294 |
+
break;
|
| 295 |
+
case cutlass::conv::Mode::kConvolution:
|
| 296 |
+
out << "conv";
|
| 297 |
+
break;
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
return out;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 304 |
+
|
| 305 |
+
// Encode Conv3dProblemSize
|
| 306 |
+
inline std::ostream &EncodeProblemSize(
|
| 307 |
+
std::ostream &out,
|
| 308 |
+
cutlass::conv::Conv3dProblemSize const &problem) {
|
| 309 |
+
|
| 310 |
+
out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_"
|
| 311 |
+
<< problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_";
|
| 312 |
+
|
| 313 |
+
out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_";
|
| 314 |
+
out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_";
|
| 315 |
+
out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_";
|
| 316 |
+
|
| 317 |
+
switch (problem.mode) {
|
| 318 |
+
case cutlass::conv::Mode::kCrossCorrelation:
|
| 319 |
+
out << "corr";
|
| 320 |
+
break;
|
| 321 |
+
case cutlass::conv::Mode::kConvolution:
|
| 322 |
+
out << "conv";
|
| 323 |
+
break;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
return out;
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 330 |
+
// Encode 3.x ConvNd ProblemShape
|
| 331 |
+
template <class ProblemShape>
|
| 332 |
+
inline std::ostream &EncodeProblemSize(
|
| 333 |
+
std::ostream &out,
|
| 334 |
+
ProblemShape const& problem_shape) {
|
| 335 |
+
|
| 336 |
+
out << problem_shape.shape_A << "_";
|
| 337 |
+
out << problem_shape.shape_B << "_";
|
| 338 |
+
|
| 339 |
+
out << "padl" << problem_shape.lower_padding << "_";
|
| 340 |
+
out << "padu" << problem_shape.upper_padding << "_";
|
| 341 |
+
out << "str" << problem_shape.traversal_stride << "_";
|
| 342 |
+
out << "dil" << problem_shape.dilation << "_";
|
| 343 |
+
|
| 344 |
+
switch (problem_shape.mode) {
|
| 345 |
+
case cutlass::conv::Mode::kCrossCorrelation:
|
| 346 |
+
out << "corr";
|
| 347 |
+
break;
|
| 348 |
+
case cutlass::conv::Mode::kConvolution:
|
| 349 |
+
out << "conv";
|
| 350 |
+
break;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
return out;
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 357 |
+
|
| 358 |
+
template <typename Element>
|
| 359 |
+
inline std::string ElementTypeName() {
|
| 360 |
+
return std::string(typeid(Element).name());
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
template <>
|
| 364 |
+
inline std::string ElementTypeName<cutlass::half_t>() {
|
| 365 |
+
return "h";
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
template <>
|
| 369 |
+
inline std::string ElementTypeName<cutlass::complex<cutlass::half_t>>() {
|
| 370 |
+
return "ch";
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
template <>
|
| 374 |
+
inline std::string ElementTypeName<cutlass::bfloat16_t>() {
|
| 375 |
+
return "bf16";
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
template <>
|
| 379 |
+
inline std::string ElementTypeName<cutlass::complex<cutlass::bfloat16_t>>() {
|
| 380 |
+
return "cbf16";
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
+
template <>
|
| 384 |
+
inline std::string ElementTypeName<cutlass::tfloat32_t>() {
|
| 385 |
+
return "tf32";
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
template <>
|
| 389 |
+
inline std::string ElementTypeName<cutlass::complex<cutlass::tfloat32_t>>() {
|
| 390 |
+
return "ctf32";
|
| 391 |
+
}
|
| 392 |
+
|
| 393 |
+
template <>
|
| 394 |
+
inline std::string ElementTypeName<cutlass::complex<float>>() {
|
| 395 |
+
return "c";
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
template <>
|
| 399 |
+
inline std::string ElementTypeName<cutlass::complex<double>>() {
|
| 400 |
+
return "z";
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
template <>
|
| 404 |
+
inline std::string ElementTypeName<cutlass::Quaternion<float>>() {
|
| 405 |
+
return "q";
|
| 406 |
+
}
|
| 407 |
+
|
| 408 |
+
template <>
|
| 409 |
+
inline std::string ElementTypeName<int8_t>() {
|
| 410 |
+
return "s8";
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
template <>
|
| 414 |
+
inline std::string ElementTypeName<uint8_t>() {
|
| 415 |
+
return "u8";
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
template <>
|
| 419 |
+
inline std::string ElementTypeName<cutlass::int4b_t>() {
|
| 420 |
+
return "s4";
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
template <>
|
| 424 |
+
inline std::string ElementTypeName<cutlass::uint4b_t>() {
|
| 425 |
+
return "u4";
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 429 |
+
|
| 430 |
+
template <typename Layout>
|
| 431 |
+
inline std::string LayoutTypeName() {
|
| 432 |
+
return std::string(typeid(Layout).name());
|
| 433 |
+
}
|
| 434 |
+
|
| 435 |
+
template <>
|
| 436 |
+
inline std::string LayoutTypeName<cutlass::layout::ColumnMajor>() {
|
| 437 |
+
return "n";
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
template <>
|
| 441 |
+
inline std::string LayoutTypeName<cutlass::layout::RowMajor>() {
|
| 442 |
+
return "t";
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
template <>
|
| 446 |
+
inline std::string LayoutTypeName<cutlass::layout::TensorNHWC>() {
|
| 447 |
+
return "nhwc";
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
template <>
|
| 451 |
+
inline std::string LayoutTypeName<cutlass::layout::TensorNCxHWx<32>>() {
|
| 452 |
+
return "nc32hw32";
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
template <>
|
| 456 |
+
inline std::string LayoutTypeName<cutlass::layout::TensorNCxHWx<64>>() {
|
| 457 |
+
return "nc64hw64";
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
template <>
|
| 461 |
+
inline std::string LayoutTypeName<cutlass::layout::TensorCxRSKx<32>>() {
|
| 462 |
+
return "c32rsk32";
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
template <>
|
| 466 |
+
inline std::string LayoutTypeName<cutlass::layout::TensorCxRSKx<64>>() {
|
| 467 |
+
return "c64rsk64";
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
template <>
|
| 471 |
+
inline std::string LayoutTypeName<cutlass::layout::TensorNDHWC>() {
|
| 472 |
+
return "ndhwc";
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 476 |
+
|
| 477 |
+
template <typename Element, typename Layout>
|
| 478 |
+
inline std::string TensorTypeName() {
|
| 479 |
+
std::stringstream ss;
|
| 480 |
+
ss << ElementTypeName<Element>() << LayoutTypeName<Layout>();
|
| 481 |
+
return ss.str();
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
template <typename Element>
|
| 485 |
+
inline std::string TensorTypeName() {
|
| 486 |
+
std::stringstream ss;
|
| 487 |
+
ss << ElementTypeName<Element>();
|
| 488 |
+
return ss.str();
|
| 489 |
+
}
|
| 490 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 491 |
+
|
| 492 |
+
/// Hash function on a byte array
|
| 493 |
+
/// Hash function on a byte array: standard reflected CRC-32
/// (polynomial 0xEDB88320), computed via a 256-entry lookup table.
struct CRC32 {

  uint32_t table[256];

  //
  // Methods
  //

  /// Builds the lookup table once at construction.
  CRC32() {
    for (int i = 0; i < 256; ++i) {
      uint32_t rem = static_cast<uint32_t>(i);
      for (int j = 0; j < 8; ++j) {
        uint32_t lsb = rem & 1;
        rem >>= 1;
        if (lsb) {
          rem ^= 0xedb88320;
        }
      }
      table[i] = rem;
    }
  }

  /// Computes the CRC of an array of bytes; pass the previous CRC as `crc`
  /// to continue an incremental hash across multiple buffers.
  uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const {
    uint8_t const *p = static_cast<uint8_t const *>(start);
    uint8_t const *end = p + length;

    crc = ~crc;

    while (p != end) {
      // (crc ^ byte) & 0xff selects the table row; shift folds in the rest
      crc = (crc >> 8) ^ table[(crc ^ *p) & 0xff];
      ++p;
    }

    return ~crc;
  }
};
|
| 534 |
+
|
| 535 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 536 |
+
|
| 537 |
+
template <
|
| 538 |
+
typename Element, typename Layout
|
| 539 |
+
>
|
| 540 |
+
uint32_t TensorHash(
|
| 541 |
+
cutlass::TensorView<Element, Layout> view,
|
| 542 |
+
CRC32 const &hash = CRC32(),
|
| 543 |
+
uint32_t crc = uint32_t()
|
| 544 |
+
) {
|
| 545 |
+
|
| 546 |
+
return hash(view.data(), view.capacity() * cutlass::sizeof_bits<Element>::value / 8, crc);
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
template <typename Element>
|
| 550 |
+
uint32_t TensorHash(
|
| 551 |
+
thrust::universal_vector<Element>& tensor,
|
| 552 |
+
CRC32 const &hash = CRC32(),
|
| 553 |
+
uint32_t crc = uint32_t()
|
| 554 |
+
) {
|
| 555 |
+
|
| 556 |
+
return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits<Element>::value / 8, crc);
|
| 557 |
+
}
|
| 558 |
+
|
| 559 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 560 |
+
|
| 561 |
+
template <
|
| 562 |
+
typename ElementA, typename LayoutA,
|
| 563 |
+
typename ElementB, typename LayoutB,
|
| 564 |
+
typename ElementC, typename LayoutC,
|
| 565 |
+
typename ElementAccumulator,
|
| 566 |
+
typename ElementCompute
|
| 567 |
+
>
|
| 568 |
+
inline std::ostream &EncodeTypes(
|
| 569 |
+
std::ostream &out
|
| 570 |
+
) {
|
| 571 |
+
|
| 572 |
+
out << TensorTypeName<ElementA, LayoutA>() << "_"
|
| 573 |
+
<< TensorTypeName<ElementB, LayoutB>() << "_"
|
| 574 |
+
<< TensorTypeName<ElementC, LayoutC>() << "_"
|
| 575 |
+
<< ElementTypeName<ElementAccumulator>() << "_"
|
| 576 |
+
<< ElementTypeName<ElementCompute>();
|
| 577 |
+
|
| 578 |
+
return out;
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
template <
|
| 582 |
+
typename ElementA,
|
| 583 |
+
typename ElementB,
|
| 584 |
+
typename ElementC,
|
| 585 |
+
typename ElementD
|
| 586 |
+
>
|
| 587 |
+
inline std::ostream &EncodeTypes(
|
| 588 |
+
std::ostream &out
|
| 589 |
+
) {
|
| 590 |
+
|
| 591 |
+
out << TensorTypeName<ElementA>() << "_"
|
| 592 |
+
<< TensorTypeName<ElementB>() << "_"
|
| 593 |
+
<< TensorTypeName<ElementC>() << "_"
|
| 594 |
+
<< ElementTypeName<ElementD>();
|
| 595 |
+
|
| 596 |
+
return out;
|
| 597 |
+
}
|
| 598 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 599 |
+
|
| 600 |
+
template <
|
| 601 |
+
typename ElementA, typename LayoutA,
|
| 602 |
+
typename ElementB, typename LayoutB,
|
| 603 |
+
typename ElementC, typename LayoutC,
|
| 604 |
+
typename ElementAccumulator,
|
| 605 |
+
typename ElementCompute
|
| 606 |
+
>
|
| 607 |
+
inline CachedTestKey CreateCachedGemmTestKey(
|
| 608 |
+
cutlass::gemm::GemmCoord const &problem,
|
| 609 |
+
ElementCompute alpha,
|
| 610 |
+
ElementCompute beta,
|
| 611 |
+
cutlass::TensorView<ElementA, LayoutA> A,
|
| 612 |
+
cutlass::TensorView<ElementB, LayoutB> B,
|
| 613 |
+
cutlass::TensorView<ElementC, LayoutC> C
|
| 614 |
+
) {
|
| 615 |
+
|
| 616 |
+
CachedTestKey key;
|
| 617 |
+
|
| 618 |
+
// Encode gemm operator and problem sizes
|
| 619 |
+
key.op = "gemm";
|
| 620 |
+
|
| 621 |
+
std::stringstream ss_problem;
|
| 622 |
+
EncodeProblemSize(ss_problem, problem);
|
| 623 |
+
ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
|
| 624 |
+
key.problem = ss_problem.str();
|
| 625 |
+
|
| 626 |
+
// Encode problem data types
|
| 627 |
+
std::stringstream ss_types;
|
| 628 |
+
EncodeTypes<
|
| 629 |
+
ElementA, LayoutA,
|
| 630 |
+
ElementB, LayoutB,
|
| 631 |
+
ElementC, LayoutC,
|
| 632 |
+
ElementAccumulator,
|
| 633 |
+
ElementCompute>(ss_types);
|
| 634 |
+
key.types = ss_types.str();
|
| 635 |
+
|
| 636 |
+
// Encode hash for problem data
|
| 637 |
+
CRC32 crc_hash;
|
| 638 |
+
key.A = TensorHash(A, crc_hash);
|
| 639 |
+
key.B = TensorHash(B, crc_hash);
|
| 640 |
+
key.C = TensorHash(C, crc_hash);
|
| 641 |
+
|
| 642 |
+
return key;
|
| 643 |
+
}
|
| 644 |
+
|
| 645 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 646 |
+
|
| 647 |
+
|
| 648 |
+
template <
|
| 649 |
+
typename ElementA, typename LayoutA,
|
| 650 |
+
typename ElementB, typename LayoutB,
|
| 651 |
+
typename ElementC, typename LayoutC,
|
| 652 |
+
typename ElementAccumulator,
|
| 653 |
+
typename ElementCompute
|
| 654 |
+
>
|
| 655 |
+
inline CachedTestKey CreateCachedConv2dTestKey(
|
| 656 |
+
|
| 657 |
+
cutlass::conv::Operator conv_operator,
|
| 658 |
+
cutlass::conv::Conv2dProblemSize const &problem,
|
| 659 |
+
ElementCompute alpha,
|
| 660 |
+
ElementCompute beta,
|
| 661 |
+
cutlass::TensorView<ElementA, LayoutA> A,
|
| 662 |
+
cutlass::TensorView<ElementB, LayoutB> B,
|
| 663 |
+
cutlass::TensorView<ElementC, LayoutC> C
|
| 664 |
+
) {
|
| 665 |
+
|
| 666 |
+
CachedTestKey key;
|
| 667 |
+
|
| 668 |
+
// Encode conv2d operator and problem sizes
|
| 669 |
+
key.op = "conv2d";
|
| 670 |
+
|
| 671 |
+
std::stringstream ss_problem;
|
| 672 |
+
ss_problem << EncodeOperator(conv_operator) << "_";
|
| 673 |
+
EncodeProblemSize(ss_problem, problem);
|
| 674 |
+
ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
|
| 675 |
+
|
| 676 |
+
key.problem = ss_problem.str();
|
| 677 |
+
|
| 678 |
+
// Encode problem data types
|
| 679 |
+
std::stringstream ss_types;
|
| 680 |
+
EncodeTypes<
|
| 681 |
+
ElementA, LayoutA,
|
| 682 |
+
ElementB, LayoutB,
|
| 683 |
+
ElementC, LayoutC,
|
| 684 |
+
ElementAccumulator,
|
| 685 |
+
ElementCompute>(ss_types);
|
| 686 |
+
key.types = ss_types.str();
|
| 687 |
+
|
| 688 |
+
// Encode hash for problem data
|
| 689 |
+
CRC32 crc_hash;
|
| 690 |
+
|
| 691 |
+
key.A = TensorHash(A, crc_hash);
|
| 692 |
+
key.B = TensorHash(B, crc_hash);
|
| 693 |
+
key.C = TensorHash(C, crc_hash);
|
| 694 |
+
|
| 695 |
+
return key;
|
| 696 |
+
}
|
| 697 |
+
|
| 698 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 699 |
+
|
| 700 |
+
template <
|
| 701 |
+
typename ElementA, typename LayoutA,
|
| 702 |
+
typename ElementB, typename LayoutB,
|
| 703 |
+
typename ElementC, typename LayoutC,
|
| 704 |
+
typename ElementAccumulator,
|
| 705 |
+
typename ElementCompute
|
| 706 |
+
>
|
| 707 |
+
inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey(
|
| 708 |
+
|
| 709 |
+
cutlass::conv::Operator conv_operator,
|
| 710 |
+
cutlass::conv::Conv2dProblemSize const &problem,
|
| 711 |
+
ElementCompute alpha,
|
| 712 |
+
ElementCompute beta,
|
| 713 |
+
cutlass::TensorView<ElementA, LayoutA> A,
|
| 714 |
+
cutlass::TensorView<ElementB, LayoutB> B,
|
| 715 |
+
cutlass::TensorView<ElementC, LayoutC> C
|
| 716 |
+
) {
|
| 717 |
+
|
| 718 |
+
CachedTestKey key;
|
| 719 |
+
|
| 720 |
+
// Encode conv2d operator and problem sizes
|
| 721 |
+
key.op = "conv2d_with_broadcast";
|
| 722 |
+
|
| 723 |
+
std::stringstream ss_problem;
|
| 724 |
+
ss_problem << EncodeOperator(conv_operator) << "_";
|
| 725 |
+
EncodeProblemSize(ss_problem, problem);
|
| 726 |
+
ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
|
| 727 |
+
|
| 728 |
+
key.problem = ss_problem.str();
|
| 729 |
+
|
| 730 |
+
// Encode problem data types
|
| 731 |
+
std::stringstream ss_types;
|
| 732 |
+
EncodeTypes<
|
| 733 |
+
ElementA, LayoutA,
|
| 734 |
+
ElementB, LayoutB,
|
| 735 |
+
ElementC, LayoutC,
|
| 736 |
+
ElementAccumulator,
|
| 737 |
+
ElementCompute>(ss_types);
|
| 738 |
+
key.types = ss_types.str();
|
| 739 |
+
|
| 740 |
+
// Encode hash for problem data
|
| 741 |
+
CRC32 crc_hash;
|
| 742 |
+
|
| 743 |
+
key.A = TensorHash(A, crc_hash);
|
| 744 |
+
key.B = TensorHash(B, crc_hash);
|
| 745 |
+
key.C = TensorHash(C, crc_hash);
|
| 746 |
+
|
| 747 |
+
return key;
|
| 748 |
+
}
|
| 749 |
+
|
| 750 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 751 |
+
|
| 752 |
+
template <
|
| 753 |
+
typename ElementA, typename LayoutA,
|
| 754 |
+
typename ElementB, typename LayoutB,
|
| 755 |
+
typename ElementC, typename LayoutC,
|
| 756 |
+
typename ElementAccumulator,
|
| 757 |
+
typename ElementCompute
|
| 758 |
+
>
|
| 759 |
+
inline CachedTestKey CreateCachedConv2dWithReductionTestKey(
|
| 760 |
+
|
| 761 |
+
cutlass::conv::Operator conv_operator,
|
| 762 |
+
cutlass::conv::Conv2dProblemSize const &problem,
|
| 763 |
+
ElementCompute alpha,
|
| 764 |
+
ElementCompute beta,
|
| 765 |
+
cutlass::TensorView<ElementA, LayoutA> A,
|
| 766 |
+
cutlass::TensorView<ElementB, LayoutB> B,
|
| 767 |
+
cutlass::TensorView<ElementC, LayoutC> C
|
| 768 |
+
) {
|
| 769 |
+
|
| 770 |
+
CachedTestKey key;
|
| 771 |
+
|
| 772 |
+
// Encode conv2d operator and problem sizes
|
| 773 |
+
key.op = "conv2d_with_reduction";
|
| 774 |
+
|
| 775 |
+
std::stringstream ss_problem;
|
| 776 |
+
ss_problem << EncodeOperator(conv_operator) << "_";
|
| 777 |
+
EncodeProblemSize(ss_problem, problem);
|
| 778 |
+
ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
|
| 779 |
+
|
| 780 |
+
key.problem = ss_problem.str();
|
| 781 |
+
|
| 782 |
+
// Encode problem data types
|
| 783 |
+
std::stringstream ss_types;
|
| 784 |
+
EncodeTypes<
|
| 785 |
+
ElementA, LayoutA,
|
| 786 |
+
ElementB, LayoutB,
|
| 787 |
+
ElementC, LayoutC,
|
| 788 |
+
ElementAccumulator,
|
| 789 |
+
ElementCompute>(ss_types);
|
| 790 |
+
key.types = ss_types.str();
|
| 791 |
+
|
| 792 |
+
// Encode hash for problem data
|
| 793 |
+
CRC32 crc_hash;
|
| 794 |
+
|
| 795 |
+
key.A = TensorHash(A, crc_hash);
|
| 796 |
+
key.B = TensorHash(B, crc_hash);
|
| 797 |
+
key.C = TensorHash(C, crc_hash);
|
| 798 |
+
|
| 799 |
+
return key;
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 803 |
+
|
| 804 |
+
template <
|
| 805 |
+
typename ElementA, typename LayoutA,
|
| 806 |
+
typename ElementB, typename LayoutB,
|
| 807 |
+
typename ElementC, typename LayoutC,
|
| 808 |
+
typename ElementAccumulator,
|
| 809 |
+
typename ElementCompute
|
| 810 |
+
>
|
| 811 |
+
inline CachedTestKey CreateCachedConv3dTestKey(
|
| 812 |
+
cutlass::conv::Operator conv_operator,
|
| 813 |
+
cutlass::conv::Conv3dProblemSize const &problem,
|
| 814 |
+
ElementCompute alpha,
|
| 815 |
+
ElementCompute beta,
|
| 816 |
+
cutlass::TensorView<ElementA, LayoutA> A,
|
| 817 |
+
cutlass::TensorView<ElementB, LayoutB> B,
|
| 818 |
+
cutlass::TensorView<ElementC, LayoutC> C
|
| 819 |
+
) {
|
| 820 |
+
|
| 821 |
+
CachedTestKey key;
|
| 822 |
+
|
| 823 |
+
// Encode conv3d operator and problem sizes
|
| 824 |
+
key.op = "conv3d";
|
| 825 |
+
|
| 826 |
+
std::stringstream ss_problem;
|
| 827 |
+
|
| 828 |
+
ss_problem << EncodeOperator(conv_operator) << "_";
|
| 829 |
+
EncodeProblemSize(ss_problem, problem);
|
| 830 |
+
ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
|
| 831 |
+
|
| 832 |
+
key.problem = ss_problem.str();
|
| 833 |
+
|
| 834 |
+
// Encode problem data types
|
| 835 |
+
std::stringstream ss_types;
|
| 836 |
+
EncodeTypes<
|
| 837 |
+
ElementA, LayoutA,
|
| 838 |
+
ElementB, LayoutB,
|
| 839 |
+
ElementC, LayoutC,
|
| 840 |
+
ElementAccumulator,
|
| 841 |
+
ElementCompute>(ss_types);
|
| 842 |
+
key.types = ss_types.str();
|
| 843 |
+
|
| 844 |
+
// Encode problem data
|
| 845 |
+
CRC32 crc_hash;
|
| 846 |
+
key.A = TensorHash(A, crc_hash);
|
| 847 |
+
key.B = TensorHash(B, crc_hash);
|
| 848 |
+
key.C = TensorHash(C, crc_hash);
|
| 849 |
+
|
| 850 |
+
return key;
|
| 851 |
+
}
|
| 852 |
+
|
| 853 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 854 |
+
|
| 855 |
+
template <
|
| 856 |
+
class ProblemShape,
|
| 857 |
+
typename ElementA,
|
| 858 |
+
typename ElementB,
|
| 859 |
+
typename ElementC,
|
| 860 |
+
typename ElementD
|
| 861 |
+
>
|
| 862 |
+
inline CachedTestKey CreateCachedConvNd3xTestKey(
|
| 863 |
+
cutlass::conv::Operator conv_operator,
|
| 864 |
+
ProblemShape const& problem_shape,
|
| 865 |
+
double alpha,
|
| 866 |
+
double beta,
|
| 867 |
+
thrust::universal_vector<ElementA> A,
|
| 868 |
+
thrust::universal_vector<ElementB> B,
|
| 869 |
+
thrust::universal_vector<ElementC> C
|
| 870 |
+
) {
|
| 871 |
+
|
| 872 |
+
CachedTestKey key;
|
| 873 |
+
|
| 874 |
+
// Encode convNd operator and problem sizes
|
| 875 |
+
std::stringstream ss_op;
|
| 876 |
+
ss_op << "conv" << ProblemShape::RankS << "d";
|
| 877 |
+
key.op = ss_op.str();
|
| 878 |
+
|
| 879 |
+
std::stringstream ss_problem;
|
| 880 |
+
ss_problem << EncodeOperator(conv_operator) << "_";
|
| 881 |
+
EncodeProblemSize(ss_problem, problem_shape);
|
| 882 |
+
ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
|
| 883 |
+
key.problem = ss_problem.str();
|
| 884 |
+
|
| 885 |
+
// Encode problem data types
|
| 886 |
+
std::stringstream ss_types;
|
| 887 |
+
EncodeTypes<
|
| 888 |
+
ElementA,
|
| 889 |
+
ElementB,
|
| 890 |
+
ElementC,
|
| 891 |
+
ElementD>(ss_types);
|
| 892 |
+
key.types = ss_types.str();
|
| 893 |
+
|
| 894 |
+
// Encode problem data
|
| 895 |
+
CRC32 crc_hash;
|
| 896 |
+
key.A = TensorHash(A, crc_hash);
|
| 897 |
+
key.B = TensorHash(B, crc_hash);
|
| 898 |
+
key.C = TensorHash(C, crc_hash);
|
| 899 |
+
|
| 900 |
+
return key;
|
| 901 |
+
}
|
| 902 |
+
|
| 903 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 904 |
+
|
| 905 |
+
} // namespace test::conv::device
|
| 906 |
+
|
| 907 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h
ADDED
|
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed sizes for Conv2d problem
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <vector>
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
#include "cutlass/layout/matrix.h"
|
| 40 |
+
#include "cutlass/conv/convolution.h"
|
| 41 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 42 |
+
|
| 43 |
+
namespace test {
|
| 44 |
+
namespace conv {
|
| 45 |
+
namespace device {
|
| 46 |
+
|
| 47 |
+
using Conv2dProblemVector = std::vector<cutlass::conv::Conv2dProblemSize>;
|
| 48 |
+
|
| 49 |
+
//
|
| 50 |
+
// Structures to prune items from Conv2dProblemVector
|
| 51 |
+
//
|
| 52 |
+
// Specification template for pruning items for convolution problem lists.
//
// Abstract predicate interface: concrete specifications implement
// is_satisfied() to decide whether a given item should be retained when
// filtering a problem vector (see prune() below in this header).
template <typename T> struct Specification
{
  virtual ~Specification() = default;

  // Returns true if `item` meets this specification's criteria.
  virtual bool is_satisfied(T item) const = 0;
};
|
| 58 |
+
|
| 59 |
+
// input size (NHWC) specification
|
| 60 |
+
struct InputSizeSpecification : Specification<cutlass::conv::Conv2dProblemSize>
|
| 61 |
+
{
|
| 62 |
+
cutlass::Tensor4DCoord input_size;
|
| 63 |
+
|
| 64 |
+
InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {}
|
| 65 |
+
|
| 66 |
+
bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
|
| 67 |
+
return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C));
|
| 68 |
+
}
|
| 69 |
+
};
|
| 70 |
+
|
| 71 |
+
// stride (stride_h, stride_w) specification
|
| 72 |
+
struct StrideSpecification : Specification<cutlass::conv::Conv2dProblemSize>
|
| 73 |
+
{
|
| 74 |
+
cutlass::MatrixCoord stride;
|
| 75 |
+
|
| 76 |
+
StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {}
|
| 77 |
+
|
| 78 |
+
bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
|
| 79 |
+
return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h));
|
| 80 |
+
}
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
// channel (C,K) specification, must be multiple of minimum channel
|
| 84 |
+
struct ChannelDivisibilitySpecification : Specification<cutlass::conv::Conv2dProblemSize>
|
| 85 |
+
{
|
| 86 |
+
int channel_multiple;
|
| 87 |
+
|
| 88 |
+
ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {}
|
| 89 |
+
|
| 90 |
+
bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
|
| 91 |
+
return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0));
|
| 92 |
+
}
|
| 93 |
+
};
|
| 94 |
+
|
| 95 |
+
//
|
| 96 |
+
// Pruning function for items from Conv2dProblemVector based on a Specification
|
| 97 |
+
//
|
| 98 |
+
inline Conv2dProblemVector prune(Conv2dProblemVector const &items,
|
| 99 |
+
Specification<cutlass::conv::Conv2dProblemSize> const &spec)
|
| 100 |
+
{
|
| 101 |
+
Conv2dProblemVector pruned_list;
|
| 102 |
+
|
| 103 |
+
for (auto& p : items)
|
| 104 |
+
if (spec.is_satisfied(p))
|
| 105 |
+
pruned_list.push_back(p);
|
| 106 |
+
return pruned_list;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
////////////////////////////////////////////////////////////////////////////
|
| 111 |
+
/// Structure TestbedConv2dProblemSizes initializes and holds conv default and
|
| 112 |
+
/// important network sizes
|
| 113 |
+
////////////////////////////////////////////////////////////////////////////
|
| 114 |
+
struct TestbedConv2dProblemSizes {
|
| 115 |
+
|
| 116 |
+
//
|
| 117 |
+
// Data members
|
| 118 |
+
//
|
| 119 |
+
int minimum_channel_size;
|
| 120 |
+
|
| 121 |
+
Conv2dProblemVector conv2d_default_sizes;
|
| 122 |
+
Conv2dProblemVector conv2d_rigorous_sizes;
|
| 123 |
+
Conv2dProblemVector conv2d_resnet50_sizes;
|
| 124 |
+
Conv2dProblemVector conv2d_resnet50_sizes_perf;
|
| 125 |
+
|
| 126 |
+
//
|
| 127 |
+
// Methods
|
| 128 |
+
//
|
| 129 |
+
  /// Default ctor
  ///
  /// Populates the default, rigorous, and ResNet-50 problem lists, then
  /// filters out problems whose input channel count is not a multiple of
  /// `minimum_channel_size` (see filter_all()).
  TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) {
    initialize_conv2d_default_sizes();
    initialize_conv2d_rigorous_sizes();
    // Functional-correctness list uses batch size 1
    initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/);

    // Performance-oriented list uses a larger batch size
    initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/);
    filter_all();
  }
|
| 138 |
+
|
| 139 |
+
/// Eliminates some illegal cases
|
| 140 |
+
void filter_all() {
|
| 141 |
+
|
| 142 |
+
Conv2dProblemVector *problems_vectors[] = {
|
| 143 |
+
&conv2d_default_sizes,
|
| 144 |
+
&conv2d_rigorous_sizes,
|
| 145 |
+
&conv2d_resnet50_sizes,
|
| 146 |
+
&conv2d_resnet50_sizes_perf
|
| 147 |
+
};
|
| 148 |
+
|
| 149 |
+
for (Conv2dProblemVector *problems : problems_vectors) {
|
| 150 |
+
Conv2dProblemVector filtered;
|
| 151 |
+
|
| 152 |
+
for (cutlass::conv::Conv2dProblemSize const & problem : *problems) {
|
| 153 |
+
if (!(problem.C % minimum_channel_size)) {
|
| 154 |
+
filtered.push_back(problem);
|
| 155 |
+
}
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
*problems = filtered;
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
// Add a few standard convolution problem sizes
|
| 163 |
+
void initialize_conv2d_default_sizes() {
|
| 164 |
+
|
| 165 |
+
////////////////////////////////////////////////////////////////////////////////////////////
|
| 166 |
+
// Small input size x stride (1,1)
|
| 167 |
+
// C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 168 |
+
////////////////////////////////////////////////////////////////////////////////////////////
|
| 169 |
+
|
| 170 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 171 |
+
{1, 1, 1, minimum_channel_size}, // input size (NHWC)
|
| 172 |
+
{8, 1, 1, minimum_channel_size}, // filter size (KRSC)
|
| 173 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 174 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 175 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 176 |
+
));
|
| 177 |
+
|
| 178 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 179 |
+
{1, 1, 8, minimum_channel_size}, // input size (NHWC)
|
| 180 |
+
{8, 1, 3, minimum_channel_size}, // filter size (KRSC)
|
| 181 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 182 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 183 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 184 |
+
));
|
| 185 |
+
|
| 186 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 187 |
+
{1, 7, 8, minimum_channel_size}, // input size (NHWC)
|
| 188 |
+
{8, 3, 3, minimum_channel_size}, // filter size (KRSC)
|
| 189 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 190 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 191 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 192 |
+
));
|
| 193 |
+
|
| 194 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 195 |
+
{1, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 196 |
+
{8, 4, 4, minimum_channel_size}, // filter size (KRSC)
|
| 197 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 198 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 199 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 200 |
+
));
|
| 201 |
+
|
| 202 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 203 |
+
{2, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 204 |
+
{8, 5, 5, minimum_channel_size}, // filter size (KRSC)
|
| 205 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 206 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 207 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 208 |
+
));
|
| 209 |
+
|
| 210 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 211 |
+
{3, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 212 |
+
{8, 6, 5, minimum_channel_size}, // filter size (KRSC)
|
| 213 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 214 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 215 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 216 |
+
));
|
| 217 |
+
|
| 218 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 219 |
+
{3, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 220 |
+
{8, 6, 6, minimum_channel_size}, // filter size (KRSC)
|
| 221 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 222 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 223 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 224 |
+
));
|
| 225 |
+
|
| 226 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 227 |
+
{3, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 228 |
+
{8, 7, 7, minimum_channel_size}, // filter size (KRSC)
|
| 229 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 230 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 231 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 232 |
+
));
|
| 233 |
+
|
| 234 |
+
////////////////////////////////////////////////////////////////////////////////////////////
|
| 235 |
+
// Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0)
|
| 236 |
+
// C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 237 |
+
////////////////////////////////////////////////////////////////////////////////////////////
|
| 238 |
+
|
| 239 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 240 |
+
{1, 1, 1, minimum_channel_size}, // input size (NHWC)
|
| 241 |
+
{8, 1, 1, minimum_channel_size}, // filter size (KRSC)
|
| 242 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 243 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 244 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 245 |
+
));
|
| 246 |
+
|
| 247 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 248 |
+
{1, 1, 8, minimum_channel_size}, // input size (NHWC)
|
| 249 |
+
{8, 1, 3, minimum_channel_size}, // filter size (KRSC)
|
| 250 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 251 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 252 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 253 |
+
));
|
| 254 |
+
|
| 255 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 256 |
+
{1, 7, 8, minimum_channel_size}, // input size (NHWC)
|
| 257 |
+
{8, 3, 3, minimum_channel_size}, // filter size (KRSC)
|
| 258 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 259 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 260 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 261 |
+
));
|
| 262 |
+
|
| 263 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 264 |
+
{1, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 265 |
+
{8, 4, 4, minimum_channel_size}, // filter size (KRSC)
|
| 266 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 267 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 268 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 269 |
+
));
|
| 270 |
+
|
| 271 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 272 |
+
{2, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 273 |
+
{8, 5, 5, minimum_channel_size}, // filter size (KRSC)
|
| 274 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 275 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 276 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 277 |
+
));
|
| 278 |
+
|
| 279 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 280 |
+
{3, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 281 |
+
{8, 6, 5, minimum_channel_size}, // filter size (KRSC)
|
| 282 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 283 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 284 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 285 |
+
));
|
| 286 |
+
|
| 287 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 288 |
+
{3, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 289 |
+
{8, 6, 6, minimum_channel_size}, // filter size (KRSC)
|
| 290 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 291 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 292 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 293 |
+
));
|
| 294 |
+
|
| 295 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 296 |
+
{3, 7, 9, minimum_channel_size}, // input size (NHWC)
|
| 297 |
+
{8, 7, 7, minimum_channel_size}, // filter size (KRSC)
|
| 298 |
+
{1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
|
| 299 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 300 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 301 |
+
));
|
| 302 |
+
|
| 303 |
+
////////////////////////////////////////////////////////////////////////////////////////////
|
| 304 |
+
// Small input size x stride (2,2)
|
| 305 |
+
// C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 306 |
+
////////////////////////////////////////////////////////////////////////////////////////////
|
| 307 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 308 |
+
{1, 11, 7, minimum_channel_size}, // input size (NHWC)
|
| 309 |
+
{8, 1, 1, minimum_channel_size}, // filter size (KRSC)
|
| 310 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 311 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 312 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 313 |
+
));
|
| 314 |
+
|
| 315 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 316 |
+
{1, 11, 7, minimum_channel_size}, // input size (NHWC)
|
| 317 |
+
{8, 3, 3, minimum_channel_size}, // filter size (KRSC)
|
| 318 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 319 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 320 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 321 |
+
));
|
| 322 |
+
|
| 323 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 324 |
+
{1, 13, 11, minimum_channel_size}, // input size (NHWC)
|
| 325 |
+
{8, 1, 1, minimum_channel_size}, // filter size (KRSC)
|
| 326 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 327 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 328 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 329 |
+
));
|
| 330 |
+
|
| 331 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 332 |
+
{1, 17, 19, minimum_channel_size}, // input size (NHWC)
|
| 333 |
+
{16, 2, 2, minimum_channel_size}, // filter size (KRSC)
|
| 334 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 335 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 336 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 337 |
+
));
|
| 338 |
+
|
| 339 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 340 |
+
{1, 23, 5, minimum_channel_size}, // input size (NHWC)
|
| 341 |
+
{16, 3, 3, minimum_channel_size}, // filter size (KRSC)
|
| 342 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 343 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 344 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 345 |
+
));
|
| 346 |
+
|
| 347 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 348 |
+
{1, 13, 17, 8}, // input size (NHWC)
|
| 349 |
+
{24, 3, 3, 8}, // filter size (KRSC)
|
| 350 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 351 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 352 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 353 |
+
));
|
| 354 |
+
|
| 355 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 356 |
+
{1, 23, 21, 8}, // input size (NHWC)
|
| 357 |
+
{24, 3, 3, 8}, // filter size (KRSC)
|
| 358 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 359 |
+
{3, 3}, // stride (stride_h, stride_w)
|
| 360 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 361 |
+
));
|
| 362 |
+
|
| 363 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 364 |
+
{1, 20, 24, 8}, // input size (NHWC)
|
| 365 |
+
{40, 3, 3, 8}, // filter size (KRSC)
|
| 366 |
+
{3, 3, 3, 3}, // padding (pad_h, _, pad_w, _)
|
| 367 |
+
{3, 3}, // stride (stride_h, stride_w)
|
| 368 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 369 |
+
));
|
| 370 |
+
|
| 371 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 372 |
+
// Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1)
|
| 373 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 374 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 375 |
+
{1, 15, 19, 160}, // input size (NHWC)
|
| 376 |
+
{224, 1, 1, 160}, // filter size (KRSC)
|
| 377 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 378 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 379 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 380 |
+
));
|
| 381 |
+
|
| 382 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 383 |
+
{1, 19, 37, 160}, // input size (NHWC)
|
| 384 |
+
{224, 3, 3, 160}, // filter size (KRSC)
|
| 385 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 386 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 387 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 388 |
+
));
|
| 389 |
+
|
| 390 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 391 |
+
{1, 16, 16, 160}, // input size (NHWC)
|
| 392 |
+
{224, 2, 3, 160}, // filter size (KRSC)
|
| 393 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 394 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 395 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 396 |
+
));
|
| 397 |
+
|
| 398 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 399 |
+
{1, 23, 21, 128}, // input size (NHWC)
|
| 400 |
+
{224, 3, 3, 128}, // filter size (KRSC)
|
| 401 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 402 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 403 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 404 |
+
));
|
| 405 |
+
|
| 406 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 407 |
+
{1, 29, 37, 160}, // input size (NHWC)
|
| 408 |
+
{224, 5, 5, 160}, // filter size (KRSC)
|
| 409 |
+
{2, 2, 2, 2}, // padding (pad_h, _, pad_w, _)
|
| 410 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 411 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 412 |
+
));
|
| 413 |
+
|
| 414 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 415 |
+
// C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
|
| 416 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 417 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 418 |
+
{1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC)
|
| 419 |
+
{96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC)
|
| 420 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 421 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 422 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 423 |
+
));
|
| 424 |
+
|
| 425 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 426 |
+
{1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC)
|
| 427 |
+
{96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC)
|
| 428 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 429 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 430 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 431 |
+
));
|
| 432 |
+
|
| 433 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 434 |
+
// Medium input size, filter size (1x1, 3x3, 5x5, 7x7), stride (2, 2)
|
| 435 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 436 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 437 |
+
{1, 13, 16, 288}, // input size (NHWC)
|
| 438 |
+
{160, 5, 5, 288}, // filter size (KRSC)
|
| 439 |
+
{2, 2, 2, 2}, // padding (pad_h, _, pad_w, _)
|
| 440 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 441 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 442 |
+
));
|
| 443 |
+
|
| 444 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 445 |
+
{1, 55, 51, 256}, // input size (NHWC)
|
| 446 |
+
{512, 1, 1, 256}, // filter size (KRSC)
|
| 447 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 448 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 449 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 450 |
+
));
|
| 451 |
+
|
| 452 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 453 |
+
{1, 71, 80, 32}, // input size (NHWC)
|
| 454 |
+
{64, 5, 5, 32}, // filter size (KRSC)
|
| 455 |
+
{2, 2, 2, 2}, // padding (pad_h, _, pad_w, _)
|
| 456 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 457 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 458 |
+
));
|
| 459 |
+
|
| 460 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 461 |
+
{1, 224, 224, 8}, // input size (NHWC)
|
| 462 |
+
{64, 7, 7, 8}, // filter size (KRSC)
|
| 463 |
+
{3, 3, 3, 3}, // padding (pad_h, _, pad_w, _)
|
| 464 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 465 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 466 |
+
));
|
| 467 |
+
|
| 468 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 469 |
+
// Medium input size stride (3, 3), filter (3, 3), non-default padding
|
| 470 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 471 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 472 |
+
{1, 27, 23, 256}, // input size (NHWC)
|
| 473 |
+
{512, 3, 3, 256}, // filter size (KRSC)
|
| 474 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 475 |
+
{3, 3}, // stride (stride_h, stride_w)
|
| 476 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 477 |
+
));
|
| 478 |
+
|
| 479 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 480 |
+
// Medium input size padding > stride, asymmetric filter, padding and striding
|
| 481 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 482 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 483 |
+
{1, 27, 31, 256}, // input size (NHWC)
|
| 484 |
+
{512, 3, 3, 256}, // filter size (KRSC)
|
| 485 |
+
{5, 5, 7, 7}, // padding (pad_h, _, pad_w, _)
|
| 486 |
+
{3, 4}, // stride (stride_h, stride_w)
|
| 487 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 488 |
+
));
|
| 489 |
+
|
| 490 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 491 |
+
{1, 27, 35, 256}, // input size (NHWC)
|
| 492 |
+
{512, 7, 5, 256}, // filter size (KRSC)
|
| 493 |
+
{11, 11, 7, 7}, // padding (pad_h, _, pad_w, _)
|
| 494 |
+
{3, 5}, // stride (stride_h, stride_w)
|
| 495 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 496 |
+
));
|
| 497 |
+
|
| 498 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 499 |
+
// Medium input size *mixed* stride (1, 2) and (2, 1),
|
| 500 |
+
// filter (3, 3), default padding
|
| 501 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 502 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 503 |
+
{1, 27, 27, 256}, // input size (NHWC)
|
| 504 |
+
{512, 3, 3, 256}, // filter size (KRSC)
|
| 505 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 506 |
+
{1, 2}, // stride (stride_h, stride_w)
|
| 507 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 508 |
+
));
|
| 509 |
+
|
| 510 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 511 |
+
{1, 27, 27, 256}, // input size (NHWC)
|
| 512 |
+
{512, 3, 3, 256}, // filter size (KRSC)
|
| 513 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 514 |
+
{2, 1}, // stride (stride_h, stride_w)
|
| 515 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 516 |
+
));
|
| 517 |
+
|
| 518 |
+
/////////////////////////////////////////////////////////////////////////////
|
| 519 |
+
// Additional input size
|
| 520 |
+
/////////////////////////////////////////////////////////////////////////////
|
| 521 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 522 |
+
{3, 28, 28, 256}, // input size (NHWC)
|
| 523 |
+
{256, 2, 2, 256}, // filter size (KRSC)
|
| 524 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 525 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 526 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 527 |
+
));
|
| 528 |
+
|
| 529 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 530 |
+
{1, 32, 32, 16}, // input size (NHWC)
|
| 531 |
+
{32, 3, 3, 16}, // filter size (KRSC)
|
| 532 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 533 |
+
{6, 2}, // stride (stride_h, stride_w)
|
| 534 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 535 |
+
));
|
| 536 |
+
|
| 537 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 538 |
+
{32, 24, 32, 32}, // input size (NHWC)
|
| 539 |
+
{32, 1, 2, 32}, // filter size (KRSC)
|
| 540 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 541 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 542 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 543 |
+
));
|
| 544 |
+
|
| 545 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 546 |
+
{4, 4, 5, 128}, // input size (NHWC)
|
| 547 |
+
{256, 3, 6, 128}, // filter size (KRSC)
|
| 548 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 549 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 550 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 551 |
+
{4, 3, 3, 256} // output size (NPQK)
|
| 552 |
+
));
|
| 553 |
+
|
| 554 |
+
conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 555 |
+
{4, 2, 3, 256}, // input size (NHWC)
|
| 556 |
+
{328, 3, 5, 256}, // filter size (KRSC)
|
| 557 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 558 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 559 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 560 |
+
{4, 1, 1, 328} // output size (NPQK)
|
| 561 |
+
));
|
| 562 |
+
}
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
// Add a few large and rigorous convolution problem sizes.
// These are only compiled in when the rigorous-size build flag is set,
// because they are expensive to run.
void initialize_conv2d_rigorous_sizes() {

#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
  conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 124, 224, 96},   // input size (NHWC)
    {24, 7, 7, 96},      // filter size (KRSC)
    {1, 229, 129, 32}    // output size (NPQK)
  ));

  conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize(
    {1, 233, 35, 48},    // input size (NHWC)
    {24, 7, 5, 48},      // filter size (KRSC)
    {1, 233, 35, 24}     // output size (NPQK)
  ));
#endif

}
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
// Add resent50 layers to unit testing sizes
|
| 587 |
+
void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){
|
| 588 |
+
|
| 589 |
+
#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass
|
| 590 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 591 |
+
[1, 224, 224, 3], // input size (NHWC)
|
| 592 |
+
[64, 7, 7, 3], // filter size (KRSC)
|
| 593 |
+
[3, 3, 3, 3], // padding (pad_h, _, pad_w, _)
|
| 594 |
+
[2, 2], // stride (stride_h, stride_w)
|
| 595 |
+
[1, 1], // dilation (dilation_h, dilation_w)
|
| 596 |
+
));
|
| 597 |
+
#endif
|
| 598 |
+
|
| 599 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 600 |
+
{batch_size, 56, 56, 64}, // input size (NHWC)
|
| 601 |
+
{256, 1, 1, 64}, // filter size (KRSC)
|
| 602 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 603 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 604 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 605 |
+
));
|
| 606 |
+
|
| 607 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 608 |
+
{batch_size, 56, 56, 64}, // input size (NHWC)
|
| 609 |
+
{64, 1, 1, 64}, // filter size (KRSC)
|
| 610 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 611 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 612 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 613 |
+
));
|
| 614 |
+
|
| 615 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 616 |
+
{batch_size, 56, 56, 64}, // input size (NHWC)
|
| 617 |
+
{64, 3, 3, 64}, // filter size (KRSC)
|
| 618 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 619 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 620 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 621 |
+
));
|
| 622 |
+
|
| 623 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 624 |
+
{batch_size, 56, 56, 256}, // input size (NHWC)
|
| 625 |
+
{64, 1, 1, 256}, // filter size (KRSC)
|
| 626 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 627 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 628 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 629 |
+
));
|
| 630 |
+
|
| 631 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 632 |
+
{batch_size, 56, 56, 256}, // input size (NHWC)
|
| 633 |
+
{512, 1, 1, 256}, // filter size (KRSC)
|
| 634 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 635 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 636 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 637 |
+
));
|
| 638 |
+
|
| 639 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 640 |
+
{batch_size, 56, 56, 256}, // input size (NHWC)
|
| 641 |
+
{128, 1, 1, 256}, // filter size (KRSC)
|
| 642 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 643 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 644 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 645 |
+
));
|
| 646 |
+
|
| 647 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 648 |
+
{batch_size, 28, 28, 128}, // input size (NHWC)
|
| 649 |
+
{128, 3, 3, 128}, // filter size (KRSC)
|
| 650 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 651 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 652 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 653 |
+
));
|
| 654 |
+
|
| 655 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 656 |
+
{batch_size, 28, 28, 128}, // input size (NHWC)
|
| 657 |
+
{512, 1, 1, 128}, // filter size (KRSC)
|
| 658 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 659 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 660 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 661 |
+
));
|
| 662 |
+
|
| 663 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 664 |
+
{batch_size, 28, 28, 512}, // input size (NHWC)
|
| 665 |
+
{128, 1, 1, 512}, // filter size (KRSC)
|
| 666 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 667 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 668 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 669 |
+
));
|
| 670 |
+
|
| 671 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 672 |
+
{batch_size, 28, 28, 512}, // input size (NHWC)
|
| 673 |
+
{1024, 1, 1, 512}, // filter size (KRSC)
|
| 674 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 675 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 676 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 677 |
+
));
|
| 678 |
+
|
| 679 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 680 |
+
{batch_size, 28, 28, 512}, // input size (NHWC)
|
| 681 |
+
{256, 1, 1, 512}, // filter size (KRSC)
|
| 682 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 683 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 684 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 685 |
+
));
|
| 686 |
+
|
| 687 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 688 |
+
{batch_size, 14, 14, 256}, // input size (NHWC)
|
| 689 |
+
{256, 3, 3, 256}, // filter size (KRSC)
|
| 690 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 691 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 692 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 693 |
+
));
|
| 694 |
+
|
| 695 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 696 |
+
{batch_size, 14, 14, 256}, // input size (NHWC)
|
| 697 |
+
{1024, 1, 1, 256}, // filter size (KRSC)
|
| 698 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 699 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 700 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 701 |
+
));
|
| 702 |
+
|
| 703 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 704 |
+
{batch_size, 14, 14, 1024}, // input size (NHWC)
|
| 705 |
+
{256, 1, 1, 1024}, // filter size (KRSC)
|
| 706 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 707 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 708 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 709 |
+
));
|
| 710 |
+
|
| 711 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 712 |
+
{batch_size, 14, 14, 1024}, // input size (NHWC)
|
| 713 |
+
{2048, 1, 1, 1024}, // filter size (KRSC)
|
| 714 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 715 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 716 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 717 |
+
));
|
| 718 |
+
|
| 719 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 720 |
+
{batch_size, 14, 14, 1024}, // input size (NHWC)
|
| 721 |
+
{512, 1, 1, 1024}, // filter size (KRSC)
|
| 722 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 723 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 724 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 725 |
+
));
|
| 726 |
+
|
| 727 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 728 |
+
{batch_size, 7, 7, 512}, // input size (NHWC)
|
| 729 |
+
{512, 3, 3, 512}, // filter size (KRSC)
|
| 730 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 731 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 732 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 733 |
+
));
|
| 734 |
+
|
| 735 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 736 |
+
{batch_size, 7, 7, 512}, // input size (NHWC)
|
| 737 |
+
{2048, 1, 1, 512}, // filter size (KRSC)
|
| 738 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 739 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 740 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 741 |
+
));
|
| 742 |
+
|
| 743 |
+
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
|
| 744 |
+
{batch_size, 7, 7, 2048}, // input size (NHWC)
|
| 745 |
+
{512, 1, 1, 2048}, // filter size (KRSC)
|
| 746 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 747 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 748 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 749 |
+
));
|
| 750 |
+
}
|
| 751 |
+
|
| 752 |
+
};
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
////////////////////////////////////////////////////////////////////////////
|
| 756 |
+
/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and
|
| 757 |
+
/// important network sizes
|
| 758 |
+
////////////////////////////////////////////////////////////////////////////
|
| 759 |
+
struct TestbedGroupConv2dProblemSizes {
|
| 760 |
+
|
| 761 |
+
//
|
| 762 |
+
// Data members
|
| 763 |
+
//
|
| 764 |
+
int threadblock_n;
|
| 765 |
+
int threadblock_k;
|
| 766 |
+
int minimum_channel_size;
|
| 767 |
+
|
| 768 |
+
Conv2dProblemVector default_single_group_sizes;
|
| 769 |
+
Conv2dProblemVector default_multiple_group_sizes;
|
| 770 |
+
|
| 771 |
+
//
|
| 772 |
+
// Methods
|
| 773 |
+
//
|
| 774 |
+
/// Default ctor
|
| 775 |
+
TestbedGroupConv2dProblemSizes(
|
| 776 |
+
int threadblock_n_,
|
| 777 |
+
int threadblock_k_,
|
| 778 |
+
int minimum_channel_size_ = 64)
|
| 779 |
+
: threadblock_n (threadblock_n_),
|
| 780 |
+
threadblock_k (threadblock_k_),
|
| 781 |
+
minimum_channel_size (minimum_channel_size_) {
|
| 782 |
+
initialize_group_conv2d_default_sizes();
|
| 783 |
+
filter_all();
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
/// Eliminates some illegal cases
|
| 787 |
+
void filter_all() {
|
| 788 |
+
|
| 789 |
+
Conv2dProblemVector *problems_vectors[] = {
|
| 790 |
+
&default_single_group_sizes,
|
| 791 |
+
&default_multiple_group_sizes
|
| 792 |
+
};
|
| 793 |
+
|
| 794 |
+
for (Conv2dProblemVector *problems : problems_vectors) {
|
| 795 |
+
Conv2dProblemVector filtered;
|
| 796 |
+
|
| 797 |
+
for (cutlass::conv::Conv2dProblemSize const & problem : *problems) {
|
| 798 |
+
if (!((problem.C / problem.groups) % minimum_channel_size)) {
|
| 799 |
+
filtered.push_back(problem);
|
| 800 |
+
}
|
| 801 |
+
}
|
| 802 |
+
|
| 803 |
+
*problems = filtered;
|
| 804 |
+
}
|
| 805 |
+
}
|
| 806 |
+
|
| 807 |
+
// Add a few standard convolution problem sizes
|
| 808 |
+
void initialize_group_conv2d_default_sizes() {
|
| 809 |
+
|
| 810 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 811 |
+
// One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0
|
| 812 |
+
// One CTA calculates a single group
|
| 813 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 814 |
+
|
| 815 |
+
for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) {
|
| 816 |
+
// groups = 2, 3, 4
|
| 817 |
+
for (int groups = 2; groups < 5; ++groups) {
|
| 818 |
+
|
| 819 |
+
int conv_k = cta_per_group_k * threadblock_n * groups;
|
| 820 |
+
default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 821 |
+
{1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC)
|
| 822 |
+
{conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC)
|
| 823 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 824 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 825 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 826 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 827 |
+
1, // split_k_slices
|
| 828 |
+
groups // groups
|
| 829 |
+
));
|
| 830 |
+
|
| 831 |
+
} // loop groups
|
| 832 |
+
} // loop cta_per_group_k
|
| 833 |
+
|
| 834 |
+
// Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K
|
| 835 |
+
default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 836 |
+
{1, 8, 8, threadblock_k}, // input size (NHWC)
|
| 837 |
+
{threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC)
|
| 838 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 839 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 840 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 841 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 842 |
+
1, // split_k_slices
|
| 843 |
+
2 // groups
|
| 844 |
+
));
|
| 845 |
+
|
| 846 |
+
// Larger problem sizes
|
| 847 |
+
|
| 848 |
+
default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 849 |
+
{1, 56, 56, 696}, // input size (NHWC)
|
| 850 |
+
{768, 3, 3, 232}, // filter size (KRSC)
|
| 851 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 852 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 853 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 854 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 855 |
+
1, // split_k_slices
|
| 856 |
+
3 // groups
|
| 857 |
+
));
|
| 858 |
+
default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 859 |
+
{1, 14, 14, 1392}, // input size (NHWC)
|
| 860 |
+
{1536, 3, 3, 232}, // filter size (KRSC)
|
| 861 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 862 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 863 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 864 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 865 |
+
1, // split_k_slices
|
| 866 |
+
3 // groups
|
| 867 |
+
));
|
| 868 |
+
|
| 869 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 870 |
+
// One CTA calculate multiple groups: CTA::N % k_per_group = 0
|
| 871 |
+
////////////////////////////////////////////////////////////////////////////////////
|
| 872 |
+
|
| 873 |
+
// 2 groups per CTA
|
| 874 |
+
default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 875 |
+
{1, 8, 8, threadblock_k * 4}, // input size (NHWC)
|
| 876 |
+
{threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC)
|
| 877 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 878 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 879 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 880 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 881 |
+
1, // split_k_slices
|
| 882 |
+
2 // groups
|
| 883 |
+
));
|
| 884 |
+
|
| 885 |
+
// 2 groups per CTA and partial gemm_k
|
| 886 |
+
default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 887 |
+
{1, 8, 8, threadblock_k}, // input size (NHWC)
|
| 888 |
+
{threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC)
|
| 889 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 890 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 891 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 892 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 893 |
+
1, // split_k_slices
|
| 894 |
+
2 // groups
|
| 895 |
+
));
|
| 896 |
+
|
| 897 |
+
// 4 groups per CTA
|
| 898 |
+
default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 899 |
+
{1, 8, 8, threadblock_k * 8}, // input size (NHWC)
|
| 900 |
+
{threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC)
|
| 901 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 902 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 903 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 904 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 905 |
+
1, // split_k_slices
|
| 906 |
+
4 // groups
|
| 907 |
+
));
|
| 908 |
+
|
| 909 |
+
// 4 groups per CTA and partial gemm_k
|
| 910 |
+
default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
|
| 911 |
+
{1, 8, 8, threadblock_k * 2}, // input size (NHWC)
|
| 912 |
+
{threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC)
|
| 913 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 914 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 915 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 916 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 917 |
+
1, // split_k_slices
|
| 918 |
+
4 // groups
|
| 919 |
+
));
|
| 920 |
+
}
|
| 921 |
+
|
| 922 |
+
};
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
} // namespace device
|
| 926 |
+
} // namespace conv
|
| 927 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h
ADDED
|
@@ -0,0 +1,818 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <fstream>
|
| 37 |
+
|
| 38 |
+
#include "../../common/cutlass_unit_test.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 42 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 43 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 44 |
+
|
| 45 |
+
#include "conv2d_problems.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/util/host_tensor.h"
|
| 48 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 49 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 50 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 53 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 54 |
+
|
| 55 |
+
#include "cutlass/core_io.h"
|
| 56 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 57 |
+
|
| 58 |
+
#include "../cache_testbed_output.h"
|
| 59 |
+
|
| 60 |
+
namespace test {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace device {
|
| 63 |
+
|
| 64 |
+
template <typename Conv2d>
|
| 65 |
+
class TestbedConv2d {
|
| 66 |
+
public:
|
| 67 |
+
|
| 68 |
+
using ElementA = typename Conv2d::ElementA;
|
| 69 |
+
using LayoutA = typename Conv2d::LayoutA;
|
| 70 |
+
using ElementB = typename Conv2d::ElementB;
|
| 71 |
+
using LayoutB = typename Conv2d::LayoutB;
|
| 72 |
+
using ElementC = typename Conv2d::ElementC;
|
| 73 |
+
using LayoutC = typename Conv2d::LayoutC;
|
| 74 |
+
using ElementAccumulator = typename Conv2d::ElementAccumulator;
|
| 75 |
+
using ElementCompute = typename Conv2d::ElementCompute;
|
| 76 |
+
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
|
| 77 |
+
|
| 78 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
| 79 |
+
|
| 80 |
+
/// Reduction kernel
|
| 81 |
+
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
|
| 82 |
+
ElementAccumulator,
|
| 83 |
+
typename EpilogueOutputOp::ElementAccumulator,
|
| 84 |
+
EpilogueOutputOp::kCount
|
| 85 |
+
>;
|
| 86 |
+
|
| 87 |
+
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
|
| 88 |
+
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
|
| 89 |
+
EpilogueOutputOp,
|
| 90 |
+
ReductionOp
|
| 91 |
+
>;
|
| 92 |
+
|
| 93 |
+
using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
|
| 94 |
+
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
|
| 95 |
+
|
| 96 |
+
public:
|
| 97 |
+
|
| 98 |
+
/// Initialization
|
| 99 |
+
cutlass::Distribution::Kind init_A;
|
| 100 |
+
cutlass::Distribution::Kind init_B;
|
| 101 |
+
cutlass::Distribution::Kind init_C;
|
| 102 |
+
uint64_t seed;
|
| 103 |
+
|
| 104 |
+
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
| 105 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
|
| 106 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
| 107 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
|
| 108 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
|
| 109 |
+
|
| 110 |
+
int tested_problem_count;
|
| 111 |
+
|
| 112 |
+
public:
|
| 113 |
+
|
| 114 |
+
TestbedConv2d(
|
| 115 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 116 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 117 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 118 |
+
uint64_t seed_ = 2080
|
| 119 |
+
):
|
| 120 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {
|
| 121 |
+
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/// Helper to initialize a tensor view
|
| 125 |
+
template <typename Element, typename Layout>
|
| 126 |
+
void initialize_tensor(
|
| 127 |
+
cutlass::TensorView<Element, Layout> view,
|
| 128 |
+
cutlass::Distribution::Kind dist_kind,
|
| 129 |
+
uint64_t seed) {
|
| 130 |
+
|
| 131 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 132 |
+
|
| 133 |
+
int scope;
|
| 134 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 135 |
+
|
| 136 |
+
if (bits <= 8) {
|
| 137 |
+
scope = 2;
|
| 138 |
+
}
|
| 139 |
+
else if (bits == 16) {
|
| 140 |
+
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
|
| 141 |
+
scope = 3;
|
| 142 |
+
}
|
| 143 |
+
else {
|
| 144 |
+
scope = 5;
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
else {
|
| 148 |
+
scope = 8;
|
| 149 |
+
}
|
| 150 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 151 |
+
view, seed, scope, -scope, 0);
|
| 152 |
+
}
|
| 153 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 154 |
+
|
| 155 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 156 |
+
}
|
| 157 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 158 |
+
|
| 159 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 160 |
+
}
|
| 161 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 162 |
+
|
| 163 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 164 |
+
}
|
| 165 |
+
else {
|
| 166 |
+
}
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
void initialize(
|
| 170 |
+
cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
|
| 171 |
+
|
| 172 |
+
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
|
| 173 |
+
tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
|
| 174 |
+
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 175 |
+
tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 176 |
+
tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 177 |
+
|
| 178 |
+
initialize_tensor(tensor_A.host_view(), init_A, seed);
|
| 179 |
+
initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
|
| 180 |
+
initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
|
| 181 |
+
|
| 182 |
+
tensor_A.sync_device();
|
| 183 |
+
tensor_B.sync_device();
|
| 184 |
+
tensor_C.sync_device();
|
| 185 |
+
tensor_D_computed.sync_device();
|
| 186 |
+
tensor_D_reference.sync_device();
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
bool sufficient() const {
|
| 190 |
+
//
|
| 191 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 192 |
+
//
|
| 193 |
+
|
| 194 |
+
size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);
|
| 195 |
+
|
| 196 |
+
cudaDeviceProp properties;
|
| 197 |
+
int device_idx;
|
| 198 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 199 |
+
|
| 200 |
+
if (result != cudaSuccess) {
|
| 201 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
result = cudaGetDeviceProperties(&properties, device_idx);
|
| 205 |
+
|
| 206 |
+
if (result != cudaSuccess) {
|
| 207 |
+
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
if (properties.sharedMemPerBlockOptin < smem_size) {
|
| 211 |
+
return false;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
return true;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/// Executes one test
|
| 218 |
+
bool run(
|
| 219 |
+
cutlass::conv::Conv2dProblemSize const &problem_size,
|
| 220 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 221 |
+
ElementCompute alpha = ElementCompute(1),
|
| 222 |
+
ElementCompute beta = ElementCompute(0)) {
|
| 223 |
+
|
| 224 |
+
// Waive test if insufficient CUDA device
|
| 225 |
+
if (!sufficient()) {
|
| 226 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 227 |
+
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
|
| 228 |
+
}
|
| 229 |
+
return true;
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
// increment tested problem count run by the testbed
|
| 233 |
+
tested_problem_count++;
|
| 234 |
+
|
| 235 |
+
#if 0 // display conv2d problem size for debugging
|
| 236 |
+
std::cout << problem_size << std::endl
|
| 237 |
+
<< "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
|
| 238 |
+
<< "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
|
| 239 |
+
<< std::endl;
|
| 240 |
+
#endif
|
| 241 |
+
|
| 242 |
+
initialize(problem_size);
|
| 243 |
+
|
| 244 |
+
// configure the operator
|
| 245 |
+
Conv2d conv2d_op;
|
| 246 |
+
|
| 247 |
+
typename Conv2d::Arguments conv2d_args(
|
| 248 |
+
problem_size,
|
| 249 |
+
tensor_A.device_ref(),
|
| 250 |
+
tensor_B.device_ref(),
|
| 251 |
+
tensor_C.device_ref(),
|
| 252 |
+
tensor_D_computed.device_ref(),
|
| 253 |
+
{alpha, beta},
|
| 254 |
+
split_k_mode
|
| 255 |
+
);
|
| 256 |
+
|
| 257 |
+
// find workspace requirement for parallel split-k reduction
|
| 258 |
+
size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
|
| 259 |
+
|
| 260 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 261 |
+
|
| 262 |
+
cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());
|
| 263 |
+
|
| 264 |
+
if (status != cutlass::Status::kSuccess) {
|
| 265 |
+
cudaError_t error = cudaGetLastError();
|
| 266 |
+
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
|
| 267 |
+
return true;
|
| 268 |
+
}
|
| 269 |
+
|
| 270 |
+
// conv2d operation with parallel split-k-mode
|
| 271 |
+
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
| 272 |
+
|
| 273 |
+
// conv2d output is written to workspace in global memory
|
| 274 |
+
conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
|
| 275 |
+
// accumulate mma for each cta in k-dimension (1.0 * A * B)
|
| 276 |
+
conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
|
| 277 |
+
// update conv2d operator arguments
|
| 278 |
+
status = conv2d_op.update(conv2d_args, workspace.get());
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 282 |
+
if (status != cutlass::Status::kSuccess) {
|
| 283 |
+
return false;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
// run conv2d operator
|
| 287 |
+
status = conv2d_op();
|
| 288 |
+
|
| 289 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 290 |
+
if (status != cutlass::Status::kSuccess) {
|
| 291 |
+
std::cerr << "Failed to run." << std::endl;
|
| 292 |
+
return false;
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
| 297 |
+
|
| 298 |
+
// configure parallel reduction operator
|
| 299 |
+
ReductionDevice reduction_op;
|
| 300 |
+
|
| 301 |
+
typename ReductionDevice::Arguments reduction_args(
|
| 302 |
+
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
|
| 303 |
+
problem_size.split_k_slices,
|
| 304 |
+
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
|
| 305 |
+
{
|
| 306 |
+
reinterpret_cast<ElementAccumulator*> (workspace.get()),
|
| 307 |
+
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
tensor_D_computed.device_data(),
|
| 311 |
+
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
|
| 312 |
+
},
|
| 313 |
+
{
|
| 314 |
+
tensor_C.device_data(),
|
| 315 |
+
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
|
| 316 |
+
},
|
| 317 |
+
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
|
| 318 |
+
{alpha, beta}
|
| 319 |
+
);
|
| 320 |
+
|
| 321 |
+
status = reduction_op.initialize(reduction_args, nullptr);
|
| 322 |
+
|
| 323 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 324 |
+
if (status != cutlass::Status::kSuccess) {
|
| 325 |
+
return false;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
// run prallel reduction kernel
|
| 329 |
+
status = reduction_op();
|
| 330 |
+
|
| 331 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 332 |
+
if (status != cutlass::Status::kSuccess) {
|
| 333 |
+
return false;
|
| 334 |
+
}
|
| 335 |
+
}
|
| 336 |
+
bool passed = false;
|
| 337 |
+
|
| 338 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 339 |
+
EXPECT_EQ(result, cudaSuccess) << " device reference error: "
|
| 340 |
+
<< cudaGetErrorString(result);
|
| 341 |
+
|
| 342 |
+
tensor_D_computed.sync_host();
|
| 343 |
+
|
| 344 |
+
//
|
| 345 |
+
// Reference check - support caching results
|
| 346 |
+
//
|
| 347 |
+
|
| 348 |
+
CachedTestKey cached_test_key = CreateCachedConv2dTestKey<
|
| 349 |
+
ElementA, LayoutA,
|
| 350 |
+
ElementB, LayoutB,
|
| 351 |
+
ElementC, LayoutC,
|
| 352 |
+
ElementAccumulator,
|
| 353 |
+
ElementCompute
|
| 354 |
+
>(
|
| 355 |
+
kConvolutionalOperator,
|
| 356 |
+
problem_size,
|
| 357 |
+
alpha,
|
| 358 |
+
beta,
|
| 359 |
+
tensor_A.host_view(),
|
| 360 |
+
tensor_B.host_view(),
|
| 361 |
+
tensor_C.host_view()
|
| 362 |
+
);
|
| 363 |
+
|
| 364 |
+
//
|
| 365 |
+
// Look for the cached key
|
| 366 |
+
//
|
| 367 |
+
|
| 368 |
+
bool cached_result_loaded = false;
|
| 369 |
+
CachedTestResult cached_test_result;
|
| 370 |
+
|
| 371 |
+
std::string conv2d_result_cache_name =
|
| 372 |
+
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
|
| 373 |
+
|
| 374 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 375 |
+
|
| 376 |
+
CachedTestResultListing cached_results(conv2d_result_cache_name);
|
| 377 |
+
|
| 378 |
+
auto cached = cached_results.find(cached_test_key);
|
| 379 |
+
|
| 380 |
+
cached_result_loaded = cached.first;
|
| 381 |
+
if (cached_result_loaded) {
|
| 382 |
+
cached_test_result = cached.second;
|
| 383 |
+
}
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
if (!cached_result_loaded) {
|
| 387 |
+
|
| 388 |
+
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
| 389 |
+
|
| 390 |
+
cutlass::reference::device::Conv2d<
|
| 391 |
+
ElementA,
|
| 392 |
+
LayoutA,
|
| 393 |
+
ElementB,
|
| 394 |
+
LayoutB,
|
| 395 |
+
ElementC,
|
| 396 |
+
LayoutC,
|
| 397 |
+
ElementCompute,
|
| 398 |
+
ElementAccumulator
|
| 399 |
+
>(
|
| 400 |
+
kConvolutionalOperator,
|
| 401 |
+
problem_size,
|
| 402 |
+
tensor_A.device_ref(),
|
| 403 |
+
tensor_B.device_ref(),
|
| 404 |
+
tensor_C.device_ref(),
|
| 405 |
+
tensor_D_reference.device_ref(),
|
| 406 |
+
alpha,
|
| 407 |
+
beta);
|
| 408 |
+
|
| 409 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 410 |
+
tensor_D_reference.sync_host();
|
| 411 |
+
|
| 412 |
+
#else
|
| 413 |
+
|
| 414 |
+
cutlass::reference::host::Conv2d<
|
| 415 |
+
ElementA,
|
| 416 |
+
LayoutA,
|
| 417 |
+
ElementB,
|
| 418 |
+
LayoutB,
|
| 419 |
+
ElementC,
|
| 420 |
+
LayoutC,
|
| 421 |
+
ElementCompute,
|
| 422 |
+
ElementAccumulator
|
| 423 |
+
>(
|
| 424 |
+
kConvolutionalOperator,
|
| 425 |
+
problem_size,
|
| 426 |
+
tensor_A.host_ref(),
|
| 427 |
+
tensor_B.host_ref(),
|
| 428 |
+
tensor_C.host_ref(),
|
| 429 |
+
tensor_D_reference.host_ref(),
|
| 430 |
+
alpha,
|
| 431 |
+
beta);
|
| 432 |
+
|
| 433 |
+
#endif
|
| 434 |
+
|
| 435 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 436 |
+
|
| 437 |
+
cached_test_result.D = TensorHash(tensor_D_reference.host_view());
|
| 438 |
+
|
| 439 |
+
CachedTestResultListing cached_results(conv2d_result_cache_name);
|
| 440 |
+
|
| 441 |
+
cached_results.append(cached_test_key, cached_test_result);
|
| 442 |
+
cached_results.write(conv2d_result_cache_name);
|
| 443 |
+
}
|
| 444 |
+
} // if (!cached_result_loaded)
|
| 445 |
+
|
| 446 |
+
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
|
| 447 |
+
|
| 448 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 449 |
+
passed = (tensor_D_hash == cached_test_result.D);
|
| 450 |
+
|
| 451 |
+
EXPECT_EQ(tensor_D_hash, cached_test_result.D)
|
| 452 |
+
<< "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
|
| 453 |
+
}
|
| 454 |
+
else {
|
| 455 |
+
|
| 456 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 457 |
+
tensor_D_computed.host_view(),
|
| 458 |
+
tensor_D_reference.host_view());
|
| 459 |
+
}
|
| 460 |
+
|
| 461 |
+
EXPECT_TRUE(passed);
|
| 462 |
+
|
| 463 |
+
std::stringstream ss_problem_size_text;
|
| 464 |
+
ss_problem_size_text << "nhwc_"
|
| 465 |
+
<< problem_size.N << "x"
|
| 466 |
+
<< problem_size.H << "x"
|
| 467 |
+
<< problem_size.W << "x"
|
| 468 |
+
<< problem_size.C
|
| 469 |
+
<< "_krsc_"
|
| 470 |
+
<< problem_size.K << "x"
|
| 471 |
+
<< problem_size.R << "x"
|
| 472 |
+
<< problem_size.S << "x"
|
| 473 |
+
<< problem_size.C
|
| 474 |
+
<< "_padding_"
|
| 475 |
+
<< problem_size.pad_h << "x"
|
| 476 |
+
<< problem_size.pad_w
|
| 477 |
+
<< "_stride_"
|
| 478 |
+
<< problem_size.stride_h << "x"
|
| 479 |
+
<< problem_size.stride_w
|
| 480 |
+
<< "_dilation_"
|
| 481 |
+
<< problem_size.dilation_h << "x"
|
| 482 |
+
<< problem_size.dilation_w << "_"
|
| 483 |
+
<< (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_");
|
| 484 |
+
|
| 485 |
+
if (!passed) {
|
| 486 |
+
std::stringstream fname;
|
| 487 |
+
|
| 488 |
+
fname << "error_Conv2d_ImplicitGemm_device_"
|
| 489 |
+
<< (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
|
| 490 |
+
<< (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
|
| 491 |
+
(Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
|
| 492 |
+
(Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
|
| 493 |
+
<< ss_problem_size_text.str()
|
| 494 |
+
<< Conv2d::ThreadblockShape::kM << "x"
|
| 495 |
+
<< Conv2d::ThreadblockShape::kN << "x"
|
| 496 |
+
<< Conv2d::ThreadblockShape::kK << "_"
|
| 497 |
+
<< Conv2d::WarpShape::kM << "x"
|
| 498 |
+
<< Conv2d::WarpShape::kN << "x"
|
| 499 |
+
<< Conv2d::WarpShape::kK << ".txt";
|
| 500 |
+
|
| 501 |
+
std::cout << fname.str() << std::endl;
|
| 502 |
+
|
| 503 |
+
std::ofstream results(fname.str());
|
| 504 |
+
|
| 505 |
+
results << problem_size << std::endl;
|
| 506 |
+
|
| 507 |
+
results
|
| 508 |
+
<< "\nA:\n" << tensor_A.host_view() << "\n"
|
| 509 |
+
<< "\nB:\n" << tensor_B.host_view() << "\n"
|
| 510 |
+
<< "\nC:\n" << tensor_C.host_view() << "\n";
|
| 511 |
+
|
| 512 |
+
results << "\nD reference (hash: " << cached_test_result.D << ")\n";
|
| 513 |
+
|
| 514 |
+
if (!cached_result_loaded) {
|
| 515 |
+
results
|
| 516 |
+
<< tensor_D_reference.host_view() << "\n";
|
| 517 |
+
}
|
| 518 |
+
|
| 519 |
+
results
|
| 520 |
+
<< "\nD computed (hash: " << tensor_D_hash << ")\n"
|
| 521 |
+
<< tensor_D_computed.host_view() << "\n";
|
| 522 |
+
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
return passed;
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
};
|
| 529 |
+
|
| 530 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 531 |
+
|
| 532 |
+
template <typename ImplicitGemm>
|
| 533 |
+
bool TestSpecificConv2d(
|
| 534 |
+
const Conv2dProblemVector & problem_sizes) {
|
| 535 |
+
|
| 536 |
+
bool passed = true;
|
| 537 |
+
|
| 538 |
+
//
|
| 539 |
+
// Testbed object
|
| 540 |
+
//
|
| 541 |
+
|
| 542 |
+
TestbedConv2d<ImplicitGemm> testbed;
|
| 543 |
+
|
| 544 |
+
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 545 |
+
for(auto conv_problem : problem_sizes) {
|
| 546 |
+
|
| 547 |
+
//
|
| 548 |
+
// Test
|
| 549 |
+
//
|
| 550 |
+
|
| 551 |
+
// test mode = xcross
|
| 552 |
+
passed = testbed.run(
|
| 553 |
+
conv_problem,
|
| 554 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 555 |
+
|
| 556 |
+
if (!passed) {
|
| 557 |
+
return false;
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
// test mode = convolution
|
| 561 |
+
passed = testbed.run(
|
| 562 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 563 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 564 |
+
|
| 565 |
+
if (!passed) {
|
| 566 |
+
return false;
|
| 567 |
+
}
|
| 568 |
+
}
|
| 569 |
+
|
| 570 |
+
return true;
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 574 |
+
// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
|
| 575 |
+
// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
|
| 576 |
+
// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
| 577 |
+
// (conv_blacklist_sizes)
|
| 578 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 579 |
+
template <typename ImplicitGemm>
|
| 580 |
+
bool TestAllConv2d(
|
| 581 |
+
const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
|
| 582 |
+
const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {
|
| 583 |
+
|
| 584 |
+
bool passed = true;
|
| 585 |
+
|
| 586 |
+
//
|
| 587 |
+
// Testbed object
|
| 588 |
+
//
|
| 589 |
+
|
| 590 |
+
TestbedConv2d<ImplicitGemm> testbed;
|
| 591 |
+
|
| 592 |
+
//
|
| 593 |
+
// Get conv problem sizes to run conv operator
|
| 594 |
+
//
|
| 595 |
+
TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
|
| 596 |
+
|
| 597 |
+
// Vector of conv2d problem sizes to avoid duplicate runs
|
| 598 |
+
Conv2dProblemVector conv_tested_sizes;
|
| 599 |
+
|
| 600 |
+
// Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes)
|
| 601 |
+
std::vector<Conv2dProblemVector> problem_vectors = {
|
| 602 |
+
conv_test_sizes, // run user specified sizes
|
| 603 |
+
conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
|
| 604 |
+
//conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
|
| 605 |
+
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
|
| 606 |
+
conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
|
| 607 |
+
#endif
|
| 608 |
+
};
|
| 609 |
+
|
| 610 |
+
// Flatten 2D problem_vectors into a 1D problem_sizes
|
| 611 |
+
std::vector<cutlass::conv::Conv2dProblemSize> problem_sizes;
|
| 612 |
+
for (auto problem_vector : problem_vectors) {
|
| 613 |
+
for(auto conv_problem : problem_vector) {
|
| 614 |
+
problem_sizes.push_back(conv_problem);
|
| 615 |
+
}
|
| 616 |
+
}
|
| 617 |
+
|
| 618 |
+
// If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient)
|
| 619 |
+
// run the most rigorous problem size first
|
| 620 |
+
if (CutlassUnitTestProblemCount()) {
|
| 621 |
+
std::reverse(problem_sizes.begin(), problem_sizes.end());
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 625 |
+
for(auto conv_problem : problem_sizes) {
|
| 626 |
+
|
| 627 |
+
// Skip blacklist and avoid duplicate problem sizes
|
| 628 |
+
if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
|
| 629 |
+
std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
|
| 630 |
+
continue;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
//
|
| 634 |
+
// Procedurally disable certain cases
|
| 635 |
+
//
|
| 636 |
+
|
| 637 |
+
// CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
|
| 638 |
+
if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
|
| 639 |
+
ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
|
| 640 |
+
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
|
| 641 |
+
cutlass::conv::StrideSupport::kUnity)) {
|
| 642 |
+
if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
|
| 643 |
+
continue;
|
| 644 |
+
}
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
// Fixed channels algorithm requires channel count to match access size
|
| 648 |
+
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
|
| 649 |
+
cutlass::conv::IteratorAlgorithm::kFixedChannels) {
|
| 650 |
+
if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) {
|
| 651 |
+
continue;
|
| 652 |
+
}
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
// Few channels algorithm requires channel count to match access size
|
| 656 |
+
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
|
| 657 |
+
cutlass::conv::IteratorAlgorithm::kFewChannels) {
|
| 658 |
+
if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) {
|
| 659 |
+
continue;
|
| 660 |
+
}
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
// CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
|
| 664 |
+
// Although strided dgrad works for all stride combinations, we are only going
|
| 665 |
+
// to run strided dgrad for non-unity strides
|
| 666 |
+
if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
|
| 667 |
+
ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
|
| 668 |
+
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
|
| 669 |
+
cutlass::conv::StrideSupport::kStrided)) {
|
| 670 |
+
if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
|
| 671 |
+
continue;
|
| 672 |
+
}
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
//
|
| 676 |
+
// Test
|
| 677 |
+
//
|
| 678 |
+
// push back tested problem size to avoid re-running duplicates
|
| 679 |
+
conv_tested_sizes.push_back(conv_problem);
|
| 680 |
+
|
| 681 |
+
// test mode = xcross
|
| 682 |
+
passed = testbed.run(
|
| 683 |
+
conv_problem,
|
| 684 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 685 |
+
|
| 686 |
+
if (!passed) {
|
| 687 |
+
return false;
|
| 688 |
+
}
|
| 689 |
+
|
| 690 |
+
// test mode = convolution
|
| 691 |
+
passed = testbed.run(
|
| 692 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 693 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 694 |
+
|
| 695 |
+
if (!passed) {
|
| 696 |
+
return false;
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
// If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts
|
| 700 |
+
if (CutlassUnitTestProblemCount() &&
|
| 701 |
+
testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
|
| 702 |
+
return true;
|
| 703 |
+
}
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
// Small-channels convolution can't run here.
|
| 707 |
+
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
|
| 708 |
+
cutlass::conv::IteratorAlgorithm::kFixedChannels) {
|
| 709 |
+
|
| 710 |
+
return true;
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
// Small-channels convolution can't run here.
|
| 714 |
+
if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
|
| 715 |
+
cutlass::conv::IteratorAlgorithm::kFewChannels) {
|
| 716 |
+
|
| 717 |
+
return true;
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
// CUTLASS DGRAD's *strided* specialization does not support split-k mode
|
| 721 |
+
if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
|
| 722 |
+
ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
|
| 723 |
+
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
|
| 724 |
+
cutlass::conv::StrideSupport::kStrided)) {
|
| 725 |
+
|
| 726 |
+
passed = testbed.run(
|
| 727 |
+
cutlass::conv::Conv2dProblemSize(
|
| 728 |
+
{1, 56, 56, 8}, // input size (NHWC)
|
| 729 |
+
{8, 1, 1, 8}, // filter size (KRSC)
|
| 730 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 731 |
+
{2, 2}, // stride (stride_h, stride_w)
|
| 732 |
+
{1, 1}), // dilation (dilation_h, dilation_w)
|
| 733 |
+
cutlass::conv::SplitKMode::kSerial,
|
| 734 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
|
| 735 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));
|
| 736 |
+
|
| 737 |
+
passed = testbed.run(
|
| 738 |
+
cutlass::conv::Conv2dProblemSize(
|
| 739 |
+
{1, 56, 56, 8}, // input size (NHWC)
|
| 740 |
+
{8, 1, 1, 8}, // filter size (KRSC)
|
| 741 |
+
{0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
|
| 742 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 743 |
+
{1, 1}) // dilation (dilation_h, dilation_w)
|
| 744 |
+
.reset_split_k_slices(2),
|
| 745 |
+
cutlass::conv::SplitKMode::kSerial,
|
| 746 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
|
| 747 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));
|
| 748 |
+
|
| 749 |
+
if (!passed) {
|
| 750 |
+
return false;
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
return passed;
|
| 754 |
+
}
|
| 755 |
+
// Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
|
| 756 |
+
// a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
|
| 757 |
+
// which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
|
| 758 |
+
// alpha and beta for local testing, but only runs one value for alpha and beta.
|
| 759 |
+
cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
|
| 760 |
+
{1, 17, 11, 288}, // input size (NHWC)
|
| 761 |
+
{160, 3, 3, 288}, // filter size (KRSC)
|
| 762 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 763 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 764 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 765 |
+
);
|
| 766 |
+
|
| 767 |
+
cutlass::conv::SplitKMode split_k_modes [] = {
|
| 768 |
+
cutlass::conv::SplitKMode::kSerial,
|
| 769 |
+
cutlass::conv::SplitKMode::kParallel,
|
| 770 |
+
};
|
| 771 |
+
|
| 772 |
+
int split_k_slices[] = {
|
| 773 |
+
1, 2, 3, 4, 201
|
| 774 |
+
};
|
| 775 |
+
|
| 776 |
+
double problem_alpha[] = {
|
| 777 |
+
2.0
|
| 778 |
+
};
|
| 779 |
+
|
| 780 |
+
double problem_beta[] = {
|
| 781 |
+
2.0
|
| 782 |
+
};
|
| 783 |
+
|
| 784 |
+
for (auto split_k_mode : split_k_modes) {
|
| 785 |
+
for (auto split_k_slice : split_k_slices) {
|
| 786 |
+
for (auto alpha : problem_alpha) {
|
| 787 |
+
for (auto beta : problem_beta) {
|
| 788 |
+
|
| 789 |
+
passed = testbed.run(
|
| 790 |
+
conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
|
| 791 |
+
split_k_mode,
|
| 792 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
|
| 793 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
|
| 794 |
+
|
| 795 |
+
if (!passed) {
|
| 796 |
+
return false;
|
| 797 |
+
}
|
| 798 |
+
|
| 799 |
+
// If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts
|
| 800 |
+
if (CutlassUnitTestProblemCount() &&
|
| 801 |
+
testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
|
| 802 |
+
return true;
|
| 803 |
+
}
|
| 804 |
+
}
|
| 805 |
+
}
|
| 806 |
+
}
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
return passed;
|
| 810 |
+
}
|
| 811 |
+
|
| 812 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 813 |
+
|
| 814 |
+
} // namespace device
|
| 815 |
+
} // namespace conv
|
| 816 |
+
} // namespace test
|
| 817 |
+
|
| 818 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <fstream>
|
| 37 |
+
|
| 38 |
+
#include "../../common/cutlass_unit_test.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 42 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 43 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 44 |
+
|
| 45 |
+
#include "conv2d_problems.h"
|
| 46 |
+
|
| 47 |
+
#include "cutlass/util/host_tensor.h"
|
| 48 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 49 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 50 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 51 |
+
#include "cutlass/util/host_reorder.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 54 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 55 |
+
|
| 56 |
+
#include "cutlass/core_io.h"
|
| 57 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 58 |
+
|
| 59 |
+
#include "../cache_testbed_output.h"
|
| 60 |
+
|
| 61 |
+
namespace test {
|
| 62 |
+
namespace conv {
|
| 63 |
+
namespace device {
|
| 64 |
+
|
| 65 |
+
template <typename Conv2d, int InterleavedK>
|
| 66 |
+
class InterleavedTestbedConv2d {
|
| 67 |
+
public:
|
| 68 |
+
|
| 69 |
+
using ElementA = typename Conv2d::ElementA;
|
| 70 |
+
using LayoutA = typename Conv2d::LayoutA;
|
| 71 |
+
using ElementB = typename Conv2d::ElementB;
|
| 72 |
+
using LayoutB = typename Conv2d::LayoutB;
|
| 73 |
+
using ElementC = typename Conv2d::ElementC;
|
| 74 |
+
using LayoutC = typename Conv2d::LayoutC;
|
| 75 |
+
using ElementAccumulator = typename Conv2d::ElementAccumulator;
|
| 76 |
+
using ElementCompute = typename Conv2d::ElementCompute;
|
| 77 |
+
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
|
| 78 |
+
|
| 79 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
| 80 |
+
|
| 81 |
+
/// Reduction kernel
|
| 82 |
+
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
|
| 83 |
+
ElementAccumulator,
|
| 84 |
+
typename EpilogueOutputOp::ElementAccumulator,
|
| 85 |
+
EpilogueOutputOp::kCount
|
| 86 |
+
>;
|
| 87 |
+
|
| 88 |
+
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
|
| 89 |
+
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
|
| 90 |
+
EpilogueOutputOp,
|
| 91 |
+
ReductionOp
|
| 92 |
+
>;
|
| 93 |
+
|
| 94 |
+
using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
|
| 95 |
+
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
|
| 96 |
+
|
| 97 |
+
public:
|
| 98 |
+
|
| 99 |
+
/// Initialization
|
| 100 |
+
cutlass::Distribution::Kind init_A;
|
| 101 |
+
cutlass::Distribution::Kind init_B;
|
| 102 |
+
cutlass::Distribution::Kind init_C;
|
| 103 |
+
uint64_t seed;
|
| 104 |
+
|
| 105 |
+
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
| 106 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
|
| 107 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_B_reordered;
|
| 108 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
| 109 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
|
| 110 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
|
| 111 |
+
|
| 112 |
+
public:
|
| 113 |
+
|
| 114 |
+
InterleavedTestbedConv2d(
|
| 115 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 116 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 117 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 118 |
+
uint64_t seed_ = 2080
|
| 119 |
+
):
|
| 120 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
| 121 |
+
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
/// Helper to initialize a tensor view
|
| 125 |
+
template <typename Element, typename Layout>
|
| 126 |
+
void initialize_tensor(
|
| 127 |
+
cutlass::TensorView<Element, Layout> view,
|
| 128 |
+
cutlass::Distribution::Kind dist_kind,
|
| 129 |
+
uint64_t seed) {
|
| 130 |
+
|
| 131 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 132 |
+
|
| 133 |
+
int scope;
|
| 134 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 135 |
+
|
| 136 |
+
if (bits <= 8) {
|
| 137 |
+
scope = 2;
|
| 138 |
+
}
|
| 139 |
+
else if (bits == 16) {
|
| 140 |
+
scope = 3;
|
| 141 |
+
}
|
| 142 |
+
else {
|
| 143 |
+
scope = 8;
|
| 144 |
+
}
|
| 145 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 146 |
+
view, seed, scope, -scope, 0);
|
| 147 |
+
}
|
| 148 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 149 |
+
|
| 150 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 151 |
+
}
|
| 152 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 153 |
+
|
| 154 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 155 |
+
}
|
| 156 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 157 |
+
|
| 158 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 159 |
+
}
|
| 160 |
+
else {
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
void initialize(
|
| 165 |
+
cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
|
| 166 |
+
|
| 167 |
+
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
|
| 168 |
+
tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
|
| 169 |
+
tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
|
| 170 |
+
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 171 |
+
tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 172 |
+
tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 173 |
+
|
| 174 |
+
initialize_tensor(tensor_A.host_view(), init_A, seed);
|
| 175 |
+
initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
|
| 176 |
+
initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
|
| 177 |
+
|
| 178 |
+
cutlass::reorder_convK<InterleavedK>(
|
| 179 |
+
tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size));
|
| 180 |
+
|
| 181 |
+
tensor_A.sync_device();
|
| 182 |
+
tensor_B.sync_device();
|
| 183 |
+
tensor_B_reordered.sync_device();
|
| 184 |
+
tensor_C.sync_device();
|
| 185 |
+
tensor_D_computed.sync_device();
|
| 186 |
+
tensor_D_reference.sync_device();
|
| 187 |
+
}
|
| 188 |
+
|
| 189 |
+
bool sufficient() const {
|
| 190 |
+
//
|
| 191 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 192 |
+
//
|
| 193 |
+
|
| 194 |
+
size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);
|
| 195 |
+
|
| 196 |
+
cudaDeviceProp properties;
|
| 197 |
+
int device_idx;
|
| 198 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 199 |
+
|
| 200 |
+
if (result != cudaSuccess) {
|
| 201 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
result = cudaGetDeviceProperties(&properties, device_idx);
|
| 205 |
+
|
| 206 |
+
if (result != cudaSuccess) {
|
| 207 |
+
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
if (properties.sharedMemPerMultiprocessor < smem_size) {
|
| 211 |
+
return false;
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
return true;
|
| 215 |
+
}
|
| 216 |
+
|
| 217 |
+
/// Executes one test
|
| 218 |
+
bool run(
|
| 219 |
+
cutlass::conv::Conv2dProblemSize const &problem_size,
|
| 220 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 221 |
+
ElementCompute alpha = ElementCompute(1),
|
| 222 |
+
ElementCompute beta = ElementCompute(0)) {
|
| 223 |
+
|
| 224 |
+
// Waive test if insufficient CUDA device
|
| 225 |
+
if (!sufficient()) {
|
| 226 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 227 |
+
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
|
| 228 |
+
}
|
| 229 |
+
return true;
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
#if 0 //display conv2d problem size for debugging
|
| 233 |
+
std::cout << problem_size << std::endl
|
| 234 |
+
<< "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl
|
| 235 |
+
<< "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
|
| 236 |
+
<< std::endl;
|
| 237 |
+
#endif
|
| 238 |
+
|
| 239 |
+
initialize(problem_size);
|
| 240 |
+
|
| 241 |
+
// configure the operator
|
| 242 |
+
Conv2d conv2d_op;
|
| 243 |
+
|
| 244 |
+
typename Conv2d::Arguments conv2d_args(
|
| 245 |
+
problem_size,
|
| 246 |
+
tensor_A.device_ref(),
|
| 247 |
+
tensor_B_reordered.device_ref(),
|
| 248 |
+
tensor_C.device_ref(),
|
| 249 |
+
tensor_D_computed.device_ref(),
|
| 250 |
+
{alpha, beta},
|
| 251 |
+
split_k_mode
|
| 252 |
+
);
|
| 253 |
+
|
| 254 |
+
// find workspace requirement for parallel split-k reduction
|
| 255 |
+
size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
|
| 256 |
+
|
| 257 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 258 |
+
|
| 259 |
+
cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());
|
| 260 |
+
|
| 261 |
+
// conv2d operation with parallel split-k-mode
|
| 262 |
+
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
| 263 |
+
|
| 264 |
+
// conv2d output is written to workspace in global memory
|
| 265 |
+
conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
|
| 266 |
+
// accumulate mma for each cta in k-dimension (1.0 * A * B)
|
| 267 |
+
conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
|
| 268 |
+
// update conv2d operator arguments
|
| 269 |
+
status = conv2d_op.update(conv2d_args, workspace.get());
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 273 |
+
if (status != cutlass::Status::kSuccess) {
|
| 274 |
+
return false;
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
// run conv2d operator
|
| 278 |
+
status = conv2d_op();
|
| 279 |
+
|
| 280 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 281 |
+
if (status != cutlass::Status::kSuccess) {
|
| 282 |
+
return false;
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
| 286 |
+
|
| 287 |
+
// configure parallel reduction operator
|
| 288 |
+
ReductionDevice reduction_op;
|
| 289 |
+
|
| 290 |
+
typename ReductionDevice::Arguments reduction_args(
|
| 291 |
+
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
|
| 292 |
+
problem_size.split_k_slices,
|
| 293 |
+
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
|
| 294 |
+
{
|
| 295 |
+
reinterpret_cast<ElementAccumulator*> (workspace.get()),
|
| 296 |
+
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
tensor_D_computed.device_data(),
|
| 300 |
+
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
|
| 301 |
+
},
|
| 302 |
+
{
|
| 303 |
+
tensor_C.device_data(),
|
| 304 |
+
ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
|
| 305 |
+
},
|
| 306 |
+
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
|
| 307 |
+
{alpha, beta}
|
| 308 |
+
);
|
| 309 |
+
|
| 310 |
+
status = reduction_op.initialize(reduction_args, nullptr);
|
| 311 |
+
|
| 312 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 313 |
+
if (status != cutlass::Status::kSuccess) {
|
| 314 |
+
return false;
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
// run prallel reduction kernel
|
| 318 |
+
status = reduction_op();
|
| 319 |
+
|
| 320 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 321 |
+
if (status != cutlass::Status::kSuccess) {
|
| 322 |
+
return false;
|
| 323 |
+
}
|
| 324 |
+
}
|
| 325 |
+
bool passed = false;
|
| 326 |
+
|
| 327 |
+
tensor_D_computed.sync_host();
|
| 328 |
+
|
| 329 |
+
//
|
| 330 |
+
// Reference check - support caching results
|
| 331 |
+
//
|
| 332 |
+
|
| 333 |
+
CachedTestKey cached_test_key = CreateCachedConv2dTestKey<
|
| 334 |
+
ElementA, LayoutA,
|
| 335 |
+
ElementB, LayoutB,
|
| 336 |
+
ElementC, LayoutC,
|
| 337 |
+
ElementAccumulator,
|
| 338 |
+
ElementCompute
|
| 339 |
+
>(
|
| 340 |
+
kConvolutionalOperator,
|
| 341 |
+
problem_size,
|
| 342 |
+
alpha,
|
| 343 |
+
beta,
|
| 344 |
+
tensor_A.host_view(),
|
| 345 |
+
tensor_B.host_view(),
|
| 346 |
+
tensor_C.host_view()
|
| 347 |
+
);
|
| 348 |
+
|
| 349 |
+
//
|
| 350 |
+
// Look for the cached key
|
| 351 |
+
//
|
| 352 |
+
|
| 353 |
+
bool cached_result_loaded = false;
|
| 354 |
+
CachedTestResult cached_test_result;
|
| 355 |
+
|
| 356 |
+
std::string conv2d_result_cache_name =
|
| 357 |
+
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
|
| 358 |
+
|
| 359 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 360 |
+
|
| 361 |
+
CachedTestResultListing cached_results(conv2d_result_cache_name);
|
| 362 |
+
|
| 363 |
+
auto cached = cached_results.find(cached_test_key);
|
| 364 |
+
|
| 365 |
+
cached_result_loaded = cached.first;
|
| 366 |
+
if (cached_result_loaded) {
|
| 367 |
+
cached_test_result = cached.second;
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
if (!cached_result_loaded) {
|
| 372 |
+
|
| 373 |
+
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
| 374 |
+
|
| 375 |
+
cutlass::reference::device::Conv2d<
|
| 376 |
+
ElementA,
|
| 377 |
+
LayoutA,
|
| 378 |
+
ElementB,
|
| 379 |
+
LayoutB,
|
| 380 |
+
ElementC,
|
| 381 |
+
LayoutC,
|
| 382 |
+
ElementCompute,
|
| 383 |
+
ElementAccumulator,
|
| 384 |
+
cutlass::NumericConverterClamp<ElementC, ElementCompute>
|
| 385 |
+
>(
|
| 386 |
+
kConvolutionalOperator,
|
| 387 |
+
problem_size,
|
| 388 |
+
tensor_A.device_ref(),
|
| 389 |
+
tensor_B.device_ref(),
|
| 390 |
+
tensor_C.device_ref(),
|
| 391 |
+
tensor_D_reference.device_ref(),
|
| 392 |
+
alpha,
|
| 393 |
+
beta);
|
| 394 |
+
|
| 395 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 396 |
+
EXPECT_EQ(result, cudaSuccess) << " device reference error: "
|
| 397 |
+
<< cudaGetErrorString(result);
|
| 398 |
+
|
| 399 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 400 |
+
tensor_D_reference.sync_host();
|
| 401 |
+
|
| 402 |
+
#else
|
| 403 |
+
|
| 404 |
+
cutlass::reference::host::Conv2d<
|
| 405 |
+
ElementA,
|
| 406 |
+
LayoutA,
|
| 407 |
+
ElementB,
|
| 408 |
+
LayoutB,
|
| 409 |
+
ElementC,
|
| 410 |
+
LayoutC,
|
| 411 |
+
ElementCompute,
|
| 412 |
+
ElementAccumulator,
|
| 413 |
+
ElementC,
|
| 414 |
+
cutlass::NumericConverterClamp<ElementC, ElementCompute>
|
| 415 |
+
>(
|
| 416 |
+
kConvolutionalOperator,
|
| 417 |
+
problem_size,
|
| 418 |
+
tensor_A.host_ref(),
|
| 419 |
+
tensor_B.host_ref(),
|
| 420 |
+
tensor_C.host_ref(),
|
| 421 |
+
tensor_D_reference.host_ref(),
|
| 422 |
+
alpha,
|
| 423 |
+
beta);
|
| 424 |
+
|
| 425 |
+
#endif
|
| 426 |
+
|
| 427 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 428 |
+
|
| 429 |
+
cached_test_result.D = TensorHash(tensor_D_reference.host_view());
|
| 430 |
+
|
| 431 |
+
CachedTestResultListing cached_results(conv2d_result_cache_name);
|
| 432 |
+
|
| 433 |
+
cached_results.append(cached_test_key, cached_test_result);
|
| 434 |
+
cached_results.write(conv2d_result_cache_name);
|
| 435 |
+
}
|
| 436 |
+
} // if (!cached_result_loaded)
|
| 437 |
+
|
| 438 |
+
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
|
| 439 |
+
|
| 440 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 441 |
+
passed = (tensor_D_hash == cached_test_result.D);
|
| 442 |
+
|
| 443 |
+
EXPECT_EQ(tensor_D_hash, cached_test_result.D)
|
| 444 |
+
<< "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
|
| 445 |
+
}
|
| 446 |
+
else {
|
| 447 |
+
|
| 448 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 449 |
+
tensor_D_computed.host_view(),
|
| 450 |
+
tensor_D_reference.host_view());
|
| 451 |
+
}
|
| 452 |
+
|
| 453 |
+
EXPECT_TRUE(passed);
|
| 454 |
+
|
| 455 |
+
if (!passed) {
|
| 456 |
+
std::stringstream fname;
|
| 457 |
+
|
| 458 |
+
fname << "error_Conv2d_ImplicitGemm_device_"
|
| 459 |
+
<< (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
|
| 460 |
+
<< (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
|
| 461 |
+
(Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
|
| 462 |
+
<< "ncxhwx_"
|
| 463 |
+
<< problem_size.N << "x"
|
| 464 |
+
<< problem_size.H << "x"
|
| 465 |
+
<< problem_size.W << "x"
|
| 466 |
+
<< problem_size.C
|
| 467 |
+
<< "_cxrskx_"
|
| 468 |
+
<< problem_size.K << "x"
|
| 469 |
+
<< problem_size.R << "x"
|
| 470 |
+
<< problem_size.S << "x"
|
| 471 |
+
<< problem_size.C
|
| 472 |
+
<< "_padding_"
|
| 473 |
+
<< problem_size.pad_h << "x"
|
| 474 |
+
<< problem_size.pad_w
|
| 475 |
+
<< "_stride_"
|
| 476 |
+
<< problem_size.stride_h << "x"
|
| 477 |
+
<< problem_size.stride_w
|
| 478 |
+
<< "_dilation_"
|
| 479 |
+
<< problem_size.dilation_h << "x"
|
| 480 |
+
<< problem_size.dilation_w << "_"
|
| 481 |
+
<< (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
|
| 482 |
+
<< Conv2d::ThreadblockShape::kM << "x"
|
| 483 |
+
<< Conv2d::ThreadblockShape::kN << "x"
|
| 484 |
+
<< Conv2d::ThreadblockShape::kK << "_"
|
| 485 |
+
<< Conv2d::WarpShape::kM << "x"
|
| 486 |
+
<< Conv2d::WarpShape::kN << "x"
|
| 487 |
+
<< Conv2d::WarpShape::kK << ".txt";
|
| 488 |
+
|
| 489 |
+
std::cout << fname.str() << std::endl;
|
| 490 |
+
|
| 491 |
+
std::ofstream results(fname.str());
|
| 492 |
+
|
| 493 |
+
results << problem_size << std::endl;
|
| 494 |
+
|
| 495 |
+
results
|
| 496 |
+
<< "\nA:\n" << tensor_A.host_view() << "\n"
|
| 497 |
+
<< "\nB:\n" << tensor_B.host_view() << "\n"
|
| 498 |
+
<< "\nC:\n" << tensor_C.host_view() << "\n";
|
| 499 |
+
|
| 500 |
+
results << "\nD reference (hash: " << cached_test_result.D << ")\n";
|
| 501 |
+
|
| 502 |
+
if (!cached_result_loaded) {
|
| 503 |
+
results
|
| 504 |
+
<< tensor_D_reference.host_view() << "\n";
|
| 505 |
+
}
|
| 506 |
+
|
| 507 |
+
results
|
| 508 |
+
<< "\nD computed (hash: " << tensor_D_hash << ")\n"
|
| 509 |
+
<< tensor_D_computed.host_view() << "\n";
|
| 510 |
+
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
return passed;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
};
|
| 517 |
+
|
| 518 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 519 |
+
// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
|
| 520 |
+
// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
|
| 521 |
+
// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
| 522 |
+
// (conv_blacklist_sizes)
|
| 523 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 524 |
+
template <typename ImplicitGemm, int InterleavedK>
|
| 525 |
+
bool TestAllInterleavedConv2d(
|
| 526 |
+
const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
|
| 527 |
+
const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {
|
| 528 |
+
|
| 529 |
+
bool passed = true;
|
| 530 |
+
|
| 531 |
+
//
|
| 532 |
+
// Testbed object
|
| 533 |
+
//
|
| 534 |
+
|
| 535 |
+
InterleavedTestbedConv2d<ImplicitGemm, InterleavedK> testbed;
|
| 536 |
+
|
| 537 |
+
//
|
| 538 |
+
// Get conv problem sizes to run conv operator
|
| 539 |
+
//
|
| 540 |
+
TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout
|
| 541 |
+
|
| 542 |
+
// Vector of conv2d problem sizes to avoid duplicate runs
|
| 543 |
+
Conv2dProblemVector conv_tested_sizes;
|
| 544 |
+
|
| 545 |
+
Conv2dProblemVector const *problem_vectors[] = {
|
| 546 |
+
&conv_test_sizes, // run user specified sizes
|
| 547 |
+
&conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
|
| 548 |
+
&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
|
| 549 |
+
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
|
| 550 |
+
&conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
|
| 551 |
+
#endif
|
| 552 |
+
};
|
| 553 |
+
|
| 554 |
+
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 555 |
+
for (Conv2dProblemVector const * problem_vector : problem_vectors) {
|
| 556 |
+
|
| 557 |
+
ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK
|
| 558 |
+
auto pruned_problem_vector = prune(*problem_vector, channel_spec);
|
| 559 |
+
|
| 560 |
+
// Run conv testbed on default convolution sizes
|
| 561 |
+
for(auto conv_problem : pruned_problem_vector) {
|
| 562 |
+
|
| 563 |
+
// Skip blacklist and avoid duplicate problem sizes
|
| 564 |
+
if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
|
| 565 |
+
std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
|
| 566 |
+
continue;
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
//
|
| 570 |
+
// Procedurally disable certain cases
|
| 571 |
+
//
|
| 572 |
+
|
| 573 |
+
// CUTLASS DGRAD's unity stride specialization only support stride {1, 1}
|
| 574 |
+
if ((ImplicitGemm::kConvolutionalOperator ==
|
| 575 |
+
cutlass::conv::Operator::kDgrad) &&
|
| 576 |
+
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
|
| 577 |
+
cutlass::conv::StrideSupport::kUnity)) {
|
| 578 |
+
if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
|
| 579 |
+
continue;
|
| 580 |
+
}
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
//
|
| 584 |
+
// Test
|
| 585 |
+
//
|
| 586 |
+
// push back tested problem size to avoid re-running duplicates
|
| 587 |
+
conv_tested_sizes.push_back(conv_problem);
|
| 588 |
+
|
| 589 |
+
// test mode = xcross
|
| 590 |
+
passed = testbed.run(
|
| 591 |
+
conv_problem,
|
| 592 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 593 |
+
|
| 594 |
+
if (!passed) {
|
| 595 |
+
return false;
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
// test mode = convolution
|
| 599 |
+
passed = testbed.run(
|
| 600 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 601 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 602 |
+
|
| 603 |
+
if (!passed) {
|
| 604 |
+
return false;
|
| 605 |
+
}
|
| 606 |
+
}
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
#if 0
|
| 610 |
+
// Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
|
| 611 |
+
// a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
|
| 612 |
+
// which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
|
| 613 |
+
// alpha and beta for local testing, but only runs one value for alpha and beta.
|
| 614 |
+
cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
|
| 615 |
+
{1, 17, 11, 288}, // input size (NHWC)
|
| 616 |
+
{160, 3, 3, 288}, // filter size (KRSC)
|
| 617 |
+
{1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
|
| 618 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 619 |
+
{1, 1} // dilation (dilation_h, dilation_w)
|
| 620 |
+
);
|
| 621 |
+
|
| 622 |
+
cutlass::conv::SplitKMode split_k_modes [] = {
|
| 623 |
+
cutlass::conv::SplitKMode::kSerial,
|
| 624 |
+
cutlass::conv::SplitKMode::kParallel,
|
| 625 |
+
};
|
| 626 |
+
|
| 627 |
+
int split_k_slices[] = {
|
| 628 |
+
1, 2, 3, 4, 201
|
| 629 |
+
};
|
| 630 |
+
|
| 631 |
+
double problem_alpha[] = {
|
| 632 |
+
2.0
|
| 633 |
+
};
|
| 634 |
+
|
| 635 |
+
double problem_beta[] = {
|
| 636 |
+
2.0
|
| 637 |
+
};
|
| 638 |
+
|
| 639 |
+
for (auto split_k_mode : split_k_modes) {
|
| 640 |
+
for (auto split_k_slice : split_k_slices) {
|
| 641 |
+
for (auto alpha : problem_alpha) {
|
| 642 |
+
for (auto beta : problem_beta) {
|
| 643 |
+
|
| 644 |
+
passed = testbed.run(
|
| 645 |
+
conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
|
| 646 |
+
split_k_mode,
|
| 647 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
|
| 648 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
|
| 649 |
+
|
| 650 |
+
if (!passed) {
|
| 651 |
+
return false;
|
| 652 |
+
}
|
| 653 |
+
}
|
| 654 |
+
}
|
| 655 |
+
}
|
| 656 |
+
}
|
| 657 |
+
#endif
|
| 658 |
+
|
| 659 |
+
return passed;
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 663 |
+
|
| 664 |
+
} // namespace device
|
| 665 |
+
} // namespace conv
|
| 666 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
/*! \file
|
| 33 |
+
\brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling
|
| 34 |
+
*/
|
| 35 |
+
|
| 36 |
+
#pragma once
|
| 37 |
+
|
| 38 |
+
#include <iostream>
|
| 39 |
+
#include <fstream>
|
| 40 |
+
#include <sstream>
|
| 41 |
+
|
| 42 |
+
#include "conv2d_problems.h"
|
| 43 |
+
#include "../../common/cutlass_unit_test.h"
|
| 44 |
+
#include "../../gemm/device/testbed_utils.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/matrix_coord.h"
|
| 47 |
+
#include "cutlass/conv/convolution.h"
|
| 48 |
+
#include "cutlass/layout/matrix.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/util/host_tensor.h"
|
| 51 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 52 |
+
#include "cutlass/util/distribution.h"
|
| 53 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 54 |
+
#include "cutlass/util/reference/host/tensor_copy.h"
|
| 55 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 56 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 57 |
+
#include "cutlass/util/reference/host/tensor_reduce.h"
|
| 58 |
+
|
| 59 |
+
namespace test {
|
| 60 |
+
namespace conv {
|
| 61 |
+
namespace device {
|
| 62 |
+
|
| 63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
template <
|
| 66 |
+
typename Conv,
|
| 67 |
+
template<typename T> class ActivationFunctor
|
| 68 |
+
>
|
| 69 |
+
struct TestbedConv2dWithAbsMax {
|
| 70 |
+
|
| 71 |
+
using ElementAccumulator = typename Conv::ElementAccumulator;
|
| 72 |
+
using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute;
|
| 73 |
+
using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor;
|
| 74 |
+
using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax;
|
| 75 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator;
|
| 76 |
+
|
| 77 |
+
static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded;
|
| 78 |
+
static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded;
|
| 79 |
+
bool doScaleA;
|
| 80 |
+
bool doScaleB;
|
| 81 |
+
bool doScaleC;
|
| 82 |
+
|
| 83 |
+
/// Initialization
|
| 84 |
+
cutlass::Distribution::Kind init_A;
|
| 85 |
+
cutlass::Distribution::Kind init_B;
|
| 86 |
+
cutlass::Distribution::Kind init_C;
|
| 87 |
+
uint64_t seed;
|
| 88 |
+
|
| 89 |
+
cutlass::HostTensor<typename Conv::ElementA, typename Conv::LayoutA> tensor_A;
|
| 90 |
+
cutlass::HostTensor<typename Conv::ElementB, typename Conv::LayoutB> tensor_B;
|
| 91 |
+
cutlass::HostTensor<typename Conv::ElementC, typename Conv::LayoutC> tensor_C;
|
| 92 |
+
cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementAuxOutput, typename Conv::LayoutC> tensor_Aux;
|
| 93 |
+
cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementOutput, typename Conv::LayoutC> tensor_D;
|
| 94 |
+
cutlass::HostTensor<typename Conv::ElementC, typename Conv::LayoutC> tensor_Vector;
|
| 95 |
+
cutlass::HostTensor<ElementAccumulator, typename Conv::LayoutC> tmp_D;
|
| 96 |
+
cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementOutput, typename Conv::LayoutC> reference_D;
|
| 97 |
+
cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementAuxOutput, typename Conv::LayoutC> reference_Aux;
|
| 98 |
+
cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_A;
|
| 99 |
+
cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_B;
|
| 100 |
+
cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_C;
|
| 101 |
+
cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_D;
|
| 102 |
+
cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_Aux;
|
| 103 |
+
cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> abs_max_Aux;
|
| 104 |
+
cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> abs_max_D;
|
| 105 |
+
cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> reference_abs_max_Aux;
|
| 106 |
+
cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> reference_abs_max_D;
|
| 107 |
+
|
| 108 |
+
//
|
| 109 |
+
// Methods
|
| 110 |
+
//
|
| 111 |
+
|
| 112 |
+
TestbedConv2dWithAbsMax(
|
| 113 |
+
bool scaleA = true,
|
| 114 |
+
bool scaleB = true,
|
| 115 |
+
bool scaleC = true,
|
| 116 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 117 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 118 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 119 |
+
uint64_t seed_ = 2080
|
| 120 |
+
):
|
| 121 |
+
doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC),
|
| 122 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
|
| 123 |
+
|
| 124 |
+
/// Helper to initialize scaling factors
|
| 125 |
+
template <typename Element, typename Layout>
|
| 126 |
+
bool initialize_scale_factor(cutlass::TensorView<Element, Layout> view, uint64_t seed, int bits=0) {
|
| 127 |
+
cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits);
|
| 128 |
+
return true;
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
/// Helper to initialize a tensor view
|
| 132 |
+
template <typename Element, typename Layout>
|
| 133 |
+
bool initialize_tensor(
|
| 134 |
+
cutlass::TensorView<Element, Layout> view,
|
| 135 |
+
cutlass::Distribution::Kind dist_kind,
|
| 136 |
+
uint64_t seed) {
|
| 137 |
+
|
| 138 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 139 |
+
|
| 140 |
+
double scope_max, scope_min;
|
| 141 |
+
int bits_input = cutlass::sizeof_bits<Element>::value;
|
| 142 |
+
int bits_output = cutlass::sizeof_bits<typename Conv::ElementC>::value;
|
| 143 |
+
|
| 144 |
+
if (bits_input == 1) {
|
| 145 |
+
scope_max = 2;
|
| 146 |
+
scope_min = 0;
|
| 147 |
+
} else if (bits_input <= 8) {
|
| 148 |
+
scope_max = 2;
|
| 149 |
+
scope_min = -2;
|
| 150 |
+
} else if (bits_output == 16) {
|
| 151 |
+
scope_max = 5;
|
| 152 |
+
scope_min = -5;
|
| 153 |
+
} else {
|
| 154 |
+
scope_max = 8;
|
| 155 |
+
scope_min = -8;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 159 |
+
view, seed, scope_max, scope_min, 0);
|
| 160 |
+
}
|
| 161 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 162 |
+
|
| 163 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 164 |
+
}
|
| 165 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 166 |
+
|
| 167 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 168 |
+
}
|
| 169 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 170 |
+
|
| 171 |
+
cutlass::reference::host::BlockFillSequential(
|
| 172 |
+
view.data(), view.capacity());
|
| 173 |
+
}
|
| 174 |
+
else {
|
| 175 |
+
EXPECT_TRUE(false) << "Not implemented";
|
| 176 |
+
return false;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
return true;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
/// Initializes data structures
|
| 183 |
+
void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) {
|
| 184 |
+
//
|
| 185 |
+
// Allocate the GEMM workspace
|
| 186 |
+
//
|
| 187 |
+
|
| 188 |
+
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
|
| 189 |
+
tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
|
| 190 |
+
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 191 |
+
tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 192 |
+
tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()});
|
| 193 |
+
reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false);
|
| 194 |
+
tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false);
|
| 195 |
+
|
| 196 |
+
EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
|
| 197 |
+
EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
|
| 198 |
+
EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
|
| 199 |
+
EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020));
|
| 200 |
+
|
| 201 |
+
// It is possible to randomly initialize to all zeros, so override this with non-zeros
|
| 202 |
+
// in the upper left corner of each operand.
|
| 203 |
+
cutlass::Coord<4> origin(0);
|
| 204 |
+
tensor_A.host_view().at(origin) = typename Conv::ElementA(1);
|
| 205 |
+
tensor_B.host_view().at(origin) = typename Conv::ElementB(1);
|
| 206 |
+
tensor_C.host_view().at(origin) = typename Conv::ElementC(1);
|
| 207 |
+
tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1);
|
| 208 |
+
|
| 209 |
+
cutlass::reference::host::TensorFill(tensor_D.host_view());
|
| 210 |
+
cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
|
| 211 |
+
|
| 212 |
+
tensor_A.sync_device();
|
| 213 |
+
tensor_B.sync_device();
|
| 214 |
+
tensor_C.sync_device();
|
| 215 |
+
tensor_D.sync_device();
|
| 216 |
+
tensor_Vector.sync_device();
|
| 217 |
+
|
| 218 |
+
int scale_bits = 2;
|
| 219 |
+
if (doScaleA) {
|
| 220 |
+
scale_A.resize({1, 1, 1, 1});
|
| 221 |
+
EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits));
|
| 222 |
+
scale_A.sync_device();
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
if (doScaleB) {
|
| 226 |
+
scale_B.resize({1, 1, 1, 1});
|
| 227 |
+
EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits));
|
| 228 |
+
scale_B.sync_device();
|
| 229 |
+
}
|
| 230 |
+
|
| 231 |
+
if (doScaleC) {
|
| 232 |
+
scale_C.resize({1, 1, 1, 1});
|
| 233 |
+
EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits));
|
| 234 |
+
scale_C.sync_device();
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
if (kScaleOutput) {
|
| 238 |
+
scale_D.resize({1, 1, 1, 1});
|
| 239 |
+
EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits));
|
| 240 |
+
scale_D.sync_device();
|
| 241 |
+
|
| 242 |
+
abs_max_D.resize({1, 1, 1, 1});
|
| 243 |
+
cutlass::reference::host::TensorFill(abs_max_D.host_view());
|
| 244 |
+
abs_max_D.sync_device();
|
| 245 |
+
|
| 246 |
+
reference_abs_max_D.resize({1, 1, 1, 1});
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
if (kScaleAux) {
|
| 250 |
+
tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 251 |
+
cutlass::reference::host::TensorFill(tensor_Aux.host_view());
|
| 252 |
+
tensor_Aux.sync_device();
|
| 253 |
+
|
| 254 |
+
scale_Aux.resize({1, 1, 1, 1});
|
| 255 |
+
EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits));
|
| 256 |
+
scale_Aux.sync_device();
|
| 257 |
+
|
| 258 |
+
abs_max_Aux.resize({1, 1, 1, 1});
|
| 259 |
+
cutlass::reference::host::TensorFill(abs_max_Aux.host_view());
|
| 260 |
+
abs_max_Aux.sync_device();
|
| 261 |
+
|
| 262 |
+
reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false);
|
| 263 |
+
reference_abs_max_Aux.resize({1, 1, 1, 1});
|
| 264 |
+
}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
/// Compares computed reference with device reference and outputs to a file if incorrect
|
| 268 |
+
bool compare_reference(
|
| 269 |
+
cutlass::conv::Conv2dProblemSize const &problem_size,
|
| 270 |
+
ElementCompute alpha,
|
| 271 |
+
ElementCompute beta) {
|
| 272 |
+
|
| 273 |
+
tensor_D.sync_host();
|
| 274 |
+
|
| 275 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
|
| 276 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
|
| 277 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
|
| 278 |
+
|
| 279 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
|
| 280 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
|
| 281 |
+
bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
|
| 282 |
+
|
| 283 |
+
if (kScaleAux) {
|
| 284 |
+
tensor_Aux.sync_host();
|
| 285 |
+
abs_max_Aux.sync_host();
|
| 286 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0);
|
| 287 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0);
|
| 288 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0);
|
| 289 |
+
passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view());
|
| 290 |
+
passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view());
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
if (kScaleOutput) {
|
| 294 |
+
abs_max_D.sync_host();
|
| 295 |
+
EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0);
|
| 296 |
+
passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view());
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
EXPECT_TRUE(passed) << " mismatched reference";
|
| 300 |
+
|
| 301 |
+
if (!passed) {
|
| 302 |
+
|
| 303 |
+
std::ofstream file0("conv_testbed_with_amax_errors_reference.txt");
|
| 304 |
+
std::ofstream file1("conv_testbed_with_amax_errors_computed.txt");
|
| 305 |
+
|
| 306 |
+
std::ofstream file("conv_testbed_with_amax_errors.txt");
|
| 307 |
+
|
| 308 |
+
file
|
| 309 |
+
<< "problem: " << problem_size
|
| 310 |
+
<< ", alpha: " << alpha << ", beta: " << beta << "\n\n";
|
| 311 |
+
|
| 312 |
+
file
|
| 313 |
+
<< "A =\n" << tensor_A.host_view()
|
| 314 |
+
<< "\nB =\n" << tensor_B.host_view()
|
| 315 |
+
<< "\nC =\n" << tensor_C.host_view()
|
| 316 |
+
<< "\nVector =\n" << tensor_Vector.host_view()
|
| 317 |
+
<< "\nScaleA = " << scale_A.host_view()
|
| 318 |
+
<< "\nScaleB = " << scale_B.host_view()
|
| 319 |
+
<< "\nScaleC = " << scale_C.host_view()
|
| 320 |
+
<< "\nScaleD = " << scale_D.host_view()
|
| 321 |
+
<< "\nScaleAux = " << scale_Aux.host_view()
|
| 322 |
+
<< std::endl;
|
| 323 |
+
|
| 324 |
+
file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl;
|
| 325 |
+
file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl;
|
| 326 |
+
if (kScaleAux) {
|
| 327 |
+
file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl;
|
| 328 |
+
file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl;
|
| 329 |
+
file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl;
|
| 330 |
+
file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl;
|
| 331 |
+
}
|
| 332 |
+
if (kScaleOutput) {
|
| 333 |
+
file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl;
|
| 334 |
+
file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl;
|
| 335 |
+
}
|
| 336 |
+
}
|
| 337 |
+
|
| 338 |
+
return passed;
|
| 339 |
+
}
|
| 340 |
+
|
| 341 |
+
/// Verifies the result is a GEMM
|
| 342 |
+
bool verify(
|
| 343 |
+
cutlass::conv::Conv2dProblemSize const &problem_size,
|
| 344 |
+
ElementCompute alpha,
|
| 345 |
+
ElementCompute beta) {
|
| 346 |
+
|
| 347 |
+
cutlass::Coord<4> origin(0);
|
| 348 |
+
ElementCompute scaled_alpha = alpha;
|
| 349 |
+
if (doScaleA) {
|
| 350 |
+
scaled_alpha *= scale_A.host_view().at(origin);
|
| 351 |
+
}
|
| 352 |
+
if (doScaleB) {
|
| 353 |
+
scaled_alpha *= scale_B.host_view().at(origin);
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
ElementCompute scaled_beta = beta;
|
| 357 |
+
if (doScaleC) {
|
| 358 |
+
scaled_beta *= scale_C.host_view().at(origin);
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
//
|
| 362 |
+
// Verify
|
| 363 |
+
//
|
| 364 |
+
|
| 365 |
+
cutlass::reference::host::Conv2d<
|
| 366 |
+
typename Conv::ElementA, typename Conv::LayoutA,
|
| 367 |
+
typename Conv::ElementB, typename Conv::LayoutB,
|
| 368 |
+
typename Conv::ElementC, typename Conv::LayoutC,
|
| 369 |
+
ElementCompute, ElementAccumulator, ElementAccumulator
|
| 370 |
+
>(
|
| 371 |
+
kConvolutionalOperator,
|
| 372 |
+
problem_size,
|
| 373 |
+
tensor_A.host_ref(),
|
| 374 |
+
tensor_B.host_ref(),
|
| 375 |
+
tensor_C.host_ref(),
|
| 376 |
+
tmp_D.host_ref(),
|
| 377 |
+
scaled_alpha,
|
| 378 |
+
scaled_beta
|
| 379 |
+
);
|
| 380 |
+
|
| 381 |
+
ElementCompute tmp_abs_max_Aux(0.);
|
| 382 |
+
ElementCompute tmp_abs_max_D(0.);
|
| 383 |
+
|
| 384 |
+
cutlass::NumericConverter<ElementCompute, typename Conv::ElementC> cvt_c_to_compute;
|
| 385 |
+
cutlass::NumericConverter<ElementCompute, ElementAccumulator> cvt_accum_to_compute;
|
| 386 |
+
cutlass::NumericConverter<ElementAbsmax, ElementCompute> cvt_compute_to_absmax;
|
| 387 |
+
cutlass::NumericConverter<typename Conv::EpilogueOutputOp::ElementOutput, ElementCompute> cvt_compute_to_d;
|
| 388 |
+
cutlass::NumericConverter<typename Conv::EpilogueOutputOp::ElementAuxOutput, ElementCompute> cvt_compute_to_aux;
|
| 389 |
+
|
| 390 |
+
cutlass::absolute_value_op<ElementCompute> abs;
|
| 391 |
+
cutlass::maximum_with_nan_propogation<ElementCompute> max;
|
| 392 |
+
ActivationFunctor<ElementCompute> act;
|
| 393 |
+
|
| 394 |
+
ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.);
|
| 395 |
+
|
| 396 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 397 |
+
for (int p = 0; p < problem_size.P; ++p) {
|
| 398 |
+
for (int q = 0; q < problem_size.Q; ++q) {
|
| 399 |
+
for (int k = 0; k < problem_size.K; ++k) {
|
| 400 |
+
ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k}));
|
| 401 |
+
ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k}));
|
| 402 |
+
ElementCompute aux = intermediate + bias;
|
| 403 |
+
ElementCompute d = act(aux);
|
| 404 |
+
tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux);
|
| 405 |
+
tmp_abs_max_D = max(abs(d), tmp_abs_max_D);
|
| 406 |
+
reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale);
|
| 407 |
+
|
| 408 |
+
if (kScaleAux) {
|
| 409 |
+
reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin));
|
| 410 |
+
}
|
| 411 |
+
}
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
if (kScaleAux) {
|
| 416 |
+
reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux);
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
if (kScaleOutput) {
|
| 420 |
+
reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D);
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
return compare_reference(problem_size, alpha, beta);
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
/// Returns true if the CUDA device is sufficient to execute the kernel.
|
| 427 |
+
bool sufficient() const {
|
| 428 |
+
//
|
| 429 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 430 |
+
//
|
| 431 |
+
|
| 432 |
+
size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage);
|
| 433 |
+
|
| 434 |
+
cudaDeviceProp properties;
|
| 435 |
+
int device_idx;
|
| 436 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 437 |
+
|
| 438 |
+
if (result != cudaSuccess) {
|
| 439 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
result = cudaGetDeviceProperties(&properties, device_idx);
|
| 443 |
+
|
| 444 |
+
if (result != cudaSuccess) {
|
| 445 |
+
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
if (properties.sharedMemPerBlockOptin < smem_size) {
|
| 449 |
+
return false;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
return true;
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
/// Executes one test
|
| 456 |
+
bool run(
|
| 457 |
+
cutlass::conv::Conv2dProblemSize const &problem_size,
|
| 458 |
+
ElementCompute alpha = ElementCompute(1),
|
| 459 |
+
ElementCompute beta = ElementCompute(0))
|
| 460 |
+
{
|
| 461 |
+
|
| 462 |
+
// Waive test if insufficient CUDA device
|
| 463 |
+
if (!sufficient()) {
|
| 464 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 465 |
+
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
|
| 466 |
+
}
|
| 467 |
+
return true;
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
this->initialize(problem_size);
|
| 471 |
+
|
| 472 |
+
//
|
| 473 |
+
// Initialize the GEMM operator
|
| 474 |
+
//
|
| 475 |
+
|
| 476 |
+
typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta};
|
| 477 |
+
typename Conv::EpilogueOutputOp::Params epilogue_params{
|
| 478 |
+
activation_params,
|
| 479 |
+
scale_A.device_data(),
|
| 480 |
+
scale_B.device_data(),
|
| 481 |
+
scale_C.device_data(),
|
| 482 |
+
scale_D.device_data(),
|
| 483 |
+
scale_Aux.device_data(),
|
| 484 |
+
abs_max_Aux.device_data(),
|
| 485 |
+
abs_max_D.device_data()
|
| 486 |
+
};
|
| 487 |
+
|
| 488 |
+
typename Conv::Arguments arguments{
|
| 489 |
+
problem_size,
|
| 490 |
+
tensor_A.device_ref(),
|
| 491 |
+
tensor_B.device_ref(),
|
| 492 |
+
tensor_C.device_ref(),
|
| 493 |
+
tensor_D.device_ref(),
|
| 494 |
+
tensor_Aux.device_ref(),
|
| 495 |
+
epilogue_params,
|
| 496 |
+
cutlass::conv::SplitKMode::kSerial,
|
| 497 |
+
tensor_Vector.device_data(),
|
| 498 |
+
0
|
| 499 |
+
};
|
| 500 |
+
|
| 501 |
+
Conv conv2d_op;
|
| 502 |
+
|
| 503 |
+
cutlass::Status status = conv2d_op.can_implement(arguments);
|
| 504 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
| 505 |
+
|
| 506 |
+
size_t workspace_size = Conv::get_workspace_size(arguments);
|
| 507 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 508 |
+
|
| 509 |
+
status = conv2d_op.initialize(arguments, workspace.get());
|
| 510 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
| 511 |
+
|
| 512 |
+
//
|
| 513 |
+
// Run the GEMM
|
| 514 |
+
//
|
| 515 |
+
|
| 516 |
+
status = conv2d_op();
|
| 517 |
+
|
| 518 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
|
| 519 |
+
|
| 520 |
+
cudaError_t cuda_error = cudaDeviceSynchronize();
|
| 521 |
+
EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error);
|
| 522 |
+
|
| 523 |
+
//
|
| 524 |
+
// Verify
|
| 525 |
+
//
|
| 526 |
+
|
| 527 |
+
bool passed = this->verify(problem_size, alpha, beta);
|
| 528 |
+
|
| 529 |
+
if (!passed) {
|
| 530 |
+
std::cout << "Failed" << std::endl;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
return passed;
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
};
|
| 537 |
+
|
| 538 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 539 |
+
|
| 540 |
+
template <
|
| 541 |
+
typename ImplicitGemm,
|
| 542 |
+
template<typename T> class ActivationFunctor = cutlass::epilogue::thread::Identity
|
| 543 |
+
>
|
| 544 |
+
bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) {
|
| 545 |
+
const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector();
|
| 546 |
+
const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector();
|
| 547 |
+
|
| 548 |
+
//
|
| 549 |
+
// Testbed object
|
| 550 |
+
//
|
| 551 |
+
|
| 552 |
+
TestbedConv2dWithAbsMax<ImplicitGemm, ActivationFunctor> testbed(scaleA, scaleB, scaleC);
|
| 553 |
+
|
| 554 |
+
//
|
| 555 |
+
// Get conv problem sizes to run conv operator
|
| 556 |
+
//
|
| 557 |
+
TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
|
| 558 |
+
|
| 559 |
+
// Vector of conv2d problem sizes to avoid duplicate runs
|
| 560 |
+
Conv2dProblemVector conv_tested_sizes;
|
| 561 |
+
|
| 562 |
+
Conv2dProblemVector const *problem_vectors[] = {
|
| 563 |
+
&conv_test_sizes, // run user specified sizes
|
| 564 |
+
&conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
|
| 565 |
+
&conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
|
| 566 |
+
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
|
| 567 |
+
&conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
|
| 568 |
+
#endif
|
| 569 |
+
};
|
| 570 |
+
|
| 571 |
+
bool passed = true;
|
| 572 |
+
|
| 573 |
+
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 574 |
+
for (Conv2dProblemVector const * problem_vector : problem_vectors) {
|
| 575 |
+
|
| 576 |
+
// Prune all problems with channels that aren't divisible by the number of elements accessed per
|
| 577 |
+
// load for operands A and B. This is meant to align with the requirements of iterators used for
|
| 578 |
+
// fprop kernels.
|
| 579 |
+
ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
|
| 580 |
+
auto pruned_problem_vector = prune(*problem_vector, channel_spec);
|
| 581 |
+
|
| 582 |
+
// Run conv testbed on default convolution sizes
|
| 583 |
+
for(auto conv_problem : pruned_problem_vector) {
|
| 584 |
+
|
| 585 |
+
// Skip blacklist and avoid duplicate problem sizes
|
| 586 |
+
if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
|
| 587 |
+
std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
|
| 588 |
+
continue;
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
//
|
| 592 |
+
// Test
|
| 593 |
+
//
|
| 594 |
+
// push back tested problem size to avoid re-running duplicates
|
| 595 |
+
conv_tested_sizes.push_back(conv_problem);
|
| 596 |
+
|
| 597 |
+
// test mode = xcross
|
| 598 |
+
passed &= testbed.run(conv_problem);
|
| 599 |
+
|
| 600 |
+
if (!passed) {
|
| 601 |
+
return false;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
// test mode = convolution
|
| 605 |
+
passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution));
|
| 606 |
+
|
| 607 |
+
if (!passed) {
|
| 608 |
+
return false;
|
| 609 |
+
}
|
| 610 |
+
}
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
return passed;
|
| 614 |
+
}
|
| 615 |
+
|
| 616 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 617 |
+
|
| 618 |
+
} // namespace device
|
| 619 |
+
} // namespace conv
|
| 620 |
+
} // namespace test
|
| 621 |
+
|
| 622 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h
ADDED
|
@@ -0,0 +1,734 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM for fused epilogue broadcast testbed
|
| 33 |
+
|
| 34 |
+
Parallel split-k is not tested because we can just use regular conv kernel
|
| 35 |
+
when we need to use parallel-splitk. Broadcast can happen in the reduction
|
| 36 |
+
kernel.
|
| 37 |
+
*/
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include <fstream>
|
| 41 |
+
|
| 42 |
+
#include "../../common/cutlass_unit_test.h"
|
| 43 |
+
#include "cutlass/cutlass.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 46 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 47 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 48 |
+
|
| 49 |
+
#include "conv2d_problems.h"
|
| 50 |
+
|
| 51 |
+
#include "cutlass/util/host_tensor.h"
|
| 52 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 53 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 54 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 55 |
+
|
| 56 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 57 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 58 |
+
|
| 59 |
+
#include "cutlass/core_io.h"
|
| 60 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 61 |
+
|
| 62 |
+
#include "../cache_testbed_output.h"
|
| 63 |
+
|
| 64 |
+
namespace test {
|
| 65 |
+
namespace conv {
|
| 66 |
+
namespace device {
|
| 67 |
+
|
| 68 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
|
| 70 |
+
template <typename Conv2d>
struct Conv2dWithBroadcastReferenceOp {

  using OutputOp = typename Conv2d::EpilogueOutputOp;

  using ElementCompute = typename OutputOp::ElementCompute;
  using ElementZ = typename OutputOp::ElementZ;
  using ElementT = typename OutputOp::ElementT;

  /// Functors mirroring the fused epilogue: a binary combine of the conv
  /// result with the broadcast vector, followed by an elementwise activation.
  typename OutputOp::BinaryOp binary_op;
  typename OutputOp::ElementwiseOp elementwise_op;

  Conv2dWithBroadcastReferenceOp() { }

  /// Computes the reference outputs for one element:
  ///   T = BinaryOp(conv2d, bias)   — intermediate tensor
  ///   Z = ElementwiseOp(T)         — final output
  void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) {
    ElementCompute combined = binary_op(conv2d, bias);
    ElementCompute activated = elementwise_op(combined);

    // Narrow from the compute type to the storage types last, so both
    // outputs are derived from the same full-precision intermediate.
    T = ElementT(combined);
    Z = ElementZ(activated);
  }
};
|
| 92 |
+
|
| 93 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
+
|
| 95 |
+
// Fused testbed
|
| 96 |
+
//
|
| 97 |
+
// Y = CONV(AB, C)
|
| 98 |
+
//
|
| 99 |
+
// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k])
|
| 100 |
+
//
|
| 101 |
+
// Z[n, p, q, k] = Elementwise(T[n, p, q, k])
|
| 102 |
+
//
|
| 103 |
+
|
| 104 |
+
template <
  typename Conv2d,
  typename ReferenceOp,
  bool AddBroadcastFirst = false
>
class TestbedConv2dWithBroadcast {
public:

  using ElementA = typename Conv2d::ElementA;
  using LayoutA = typename Conv2d::LayoutA;
  using ElementB = typename Conv2d::ElementB;
  using LayoutB = typename Conv2d::LayoutB;
  using ElementC = typename Conv2d::ElementC;
  using LayoutC = typename Conv2d::LayoutC;
  using ElementAccumulator = typename Conv2d::ElementAccumulator;
  using ElementCompute = typename Conv2d::ElementCompute;
  using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
  using ElementZ = typename EpilogueOutputOp::ElementZ;
  using ElementT = typename EpilogueOutputOp::ElementT;
  using ElementVector = typename EpilogueOutputOp::ElementVector;

  static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
  // When true, the broadcast vector is added to the accumulator BEFORE the
  // epilogue runs (the host reference then applies beta separately).
  static const bool kAddBroadcastFirst = AddBroadcastFirst;
  // When true, the kernel writes out the intermediate tensor T in addition to Z.
  static const bool kStoreT = EpilogueOutputOp::kStoreT;

public:

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  uint64_t seed;

  cutlass::HostTensor<ElementA, LayoutA> tensor_A;
  cutlass::HostTensor<ElementB, LayoutB> tensor_B;
  cutlass::HostTensor<ElementC, LayoutC> tensor_C;
  // Copy of C widened to the accumulator type for the reference computation.
  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_C_reference;
  cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_computed;
  cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_reference;
  cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
  cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
  // Raw conv output Y = alpha * conv(A, B) + beta * C, before broadcast/elementwise.
  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
  cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast;  // Input Broadcast

public:

  TestbedConv2dWithBroadcast(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {

  }

  /// Helper to initialize a tensor view according to the requested distribution.
  template <typename Element, typename Layout>
  void initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      int scope;
      int bits = cutlass::sizeof_bits<Element>::value;

      // Narrower element (and accumulator) types get a smaller value range so
      // that accumulation stays exact and comparisons can be bitwise.
      if (bits <= 8) {
        scope = 2;
      }
      else if (bits == 16) {
        if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
          scope = 3;
        }
        else {
          scope = 5;
        }
      }
      else {
        scope = 8;
      }

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, scope, -scope, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
    }
    else {
      // Unrecognized distribution: leave the tensor untouched.
    }
  }

  /// Sizes and fills all host tensors for the given problem, then copies them
  /// to the device. The broadcast vector has extent {1, 1, 1, channels}.
  void initialize(
    cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {

    tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
    tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
    tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_Broadcast.resize({
      1,
      1,
      1,
      implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(),
    });

    initialize_tensor(tensor_A.host_view(), init_A, seed);
    initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
    initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
    initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39);

    // Widen C to the accumulator type element-by-element for the reference path.
    for (int n = 0; n < tensor_C_reference.extent().n(); ++n) {
      for (int p = 0; p < tensor_C_reference.extent().h(); ++p) {
        for (int q = 0; q < tensor_C_reference.extent().w(); ++q) {
          for (int k = 0; k < tensor_C_reference.extent().c(); ++k) {
            tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k}));
          }
        }
      }
    }

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_Broadcast.sync_device();
    tensor_C_reference.sync_device();
    tensor_Z_computed.sync_device();
    tensor_Z_reference.sync_device();
    tensor_T_computed.sync_device();
    tensor_T_reference.sync_device();
    tensor_Y_reference.sync_device();
  }

  /// Returns true if the current CUDA device has enough shared memory to run
  /// the kernel; tests are waived (treated as passing) otherwise.
  bool sufficient() const {
    //
    // Determine SMEM requirements and waive if not satisfied
    //

    size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);

    cudaDeviceProp properties;
    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDevice() API call failed.");
    }

    result = cudaGetDeviceProperties(&properties, device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDeviceProperties() failed");
    }

    // sharedMemPerBlockOptin is the opt-in maximum, which is what the kernel
    // requests at launch.
    if (properties.sharedMemPerBlockOptin < smem_size) {
      return false;
    }

    return true;
  }

  /// Executes one test: runs the fused kernel on device, computes Y/T/Z on the
  /// host via ReferenceOp, and compares. Returns true on match or waive; dumps
  /// all tensors to an error file on mismatch.
  bool run(
    cutlass::conv::Conv2dProblemSize const &problem_size,
    cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
    ElementCompute alpha = ElementCompute(1),
    ElementCompute beta = ElementCompute(1)) {

    // Waive test if insufficient CUDA device
    if (!sufficient()) {
      if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
        std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
      }
      return true;
    }

#if 0 //display conv2d problem size for debugging
    std::cout << problem_size << std::endl
              << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
              << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
              << std::endl;
#endif

    initialize(problem_size);

    // configure the operator
    Conv2d conv2d_op;
    typename Conv2d::Arguments conv2d_args(
      problem_size,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C.device_ref(),
      tensor_Z_computed.device_ref(),
      {alpha, beta},
      split_k_mode,
      tensor_Broadcast.device_data(),
      kStoreT ? tensor_T_computed.device_data() : nullptr,
      0,           // This must be zero
      implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
    );

    // initialize the kernel
    size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);

    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());

    // Initialization failure means the kernel cannot run this configuration
    // on this device; treat as a waived (passing) test.
    if (status != cutlass::Status::kSuccess) {
      cudaError_t error = cudaGetLastError();
      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
      return true;
    }

    // run conv2d operator
    status = conv2d_op();

    EXPECT_TRUE(status == cutlass::Status::kSuccess);
    if (status != cutlass::Status::kSuccess) {
      return false;
    }

    bool passed = false;

    cudaError_t result = cudaDeviceSynchronize();
    EXPECT_EQ(result, cudaSuccess) << " device reference error: "
                                   << cudaGetErrorString(result);

    tensor_T_computed.sync_host();
    tensor_Z_computed.sync_host();

    //
    // Reference check
    //

    // When kAddBroadcastFirst is true, add bias on the host
    ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta;

#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED

    cutlass::reference::device::Conv2d<
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementAccumulator,
      LayoutC,
      ElementAccumulator,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C_reference.device_ref(),
      tensor_Y_reference.device_ref(),
      alpha,
      beta_ref);

    // sync host (copy device data to host) for dumping error output in case of mismatches
    tensor_Y_reference.sync_host();

#else

    cutlass::reference::host::Conv2d<
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementAccumulator,
      LayoutC,
      ElementAccumulator,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size,
      tensor_A.host_ref(),
      tensor_B.host_ref(),
      tensor_C_reference.host_ref(),
      tensor_Y_reference.host_ref(),
      alpha,
      beta_ref);

#endif
    ReferenceOp reference_op;

    // compute tensor Z and tensor T
    // Output extents differ by operator: fprop writes (N, P, Q, K); the other
    // operators write back into the activation/filter shapes.
    for (int n = 0; n < problem_size.N; ++n) {
      for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) {
        for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) {
          for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) {

            ElementZ z{};
            ElementT t{};

            ElementCompute accum = tensor_Y_reference.at({n, p, q, k});
            ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k}));


            if (kAddBroadcastFirst) {
              // Broadcast was folded into the accumulator; beta * C is applied
              // here instead of inside the reference conv (beta_ref was 0).
              reference_op(z, t, accum + bias,
                           beta * ElementCompute(tensor_C_reference.at({n, p, q, k})));
            } else {
              reference_op(z, t, accum, bias);
            }

            tensor_Z_reference.at({n, p, q, k}) = z;
            tensor_T_reference.at({n, p, q, k}) = t;
          }
        }
      }
    }

    // T is only checked when the kernel actually stores it.
    if (kStoreT) {
      passed = cutlass::reference::host::TensorEquals(
        tensor_T_computed.host_view(),
        tensor_T_reference.host_view());

      EXPECT_TRUE(passed);
    }

    passed = cutlass::reference::host::TensorEquals(
      tensor_Z_computed.host_view(),
      tensor_Z_reference.host_view());

    EXPECT_TRUE(passed);

    if (!passed) {
      // On mismatch, write every input/output tensor to a descriptively named
      // file so the failure can be reproduced and inspected offline.
      std::stringstream fname;

      fname << "error_Conv2d_ImplicitGemm_device_"
            << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
            << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
                (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
                 (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
            << "nhwc_"
            << problem_size.N << "x"
            << problem_size.H << "x"
            << problem_size.W << "x"
            << problem_size.C
            << "_krsc_"
            << problem_size.K << "x"
            << problem_size.R << "x"
            << problem_size.S << "x"
            << problem_size.C
            << "_padding_"
            << problem_size.pad_h << "x"
            << problem_size.pad_w
            << "_stride_"
            << problem_size.stride_h << "x"
            << problem_size.stride_w
            << "_dilation_"
            << problem_size.dilation_h << "x"
            << problem_size.dilation_w << "_"
            << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
            << Conv2d::ThreadblockShape::kM << "x"
            << Conv2d::ThreadblockShape::kN << "x"
            << Conv2d::ThreadblockShape::kK << "_"
            << Conv2d::WarpShape::kM << "x"
            << Conv2d::WarpShape::kN << "x"
            << Conv2d::WarpShape::kK << ".txt";

      std::cout << fname.str() << std::endl;

      std::ofstream results(fname.str());

      results << problem_size << std::endl;

      results
        << "\nA:\n" << tensor_A.host_view() << "\n"
        << "\nB:\n" << tensor_B.host_view() << "\n"
        << "\nC:\n" << tensor_C.host_view() << "\n"
        << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n"
        << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n"
        << "\nT reference:\n" << tensor_T_reference.host_view() << "\n"
        << "\nT computed:\n" << tensor_T_computed.host_view() << "\n"
        << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n"
        << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n";
    }

    return passed;
  }
};
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 505 |
+
|
| 506 |
+
template <typename ImplicitGemm,
|
| 507 |
+
typename ReferenceOp = Conv2dWithBroadcastReferenceOp<ImplicitGemm>,
|
| 508 |
+
bool AddBroadcastFirst = false>
|
| 509 |
+
bool TestSpecificConv2dWithBroadcast(
|
| 510 |
+
const Conv2dProblemVector & problem_sizes) {
|
| 511 |
+
|
| 512 |
+
bool passed = true;
|
| 513 |
+
|
| 514 |
+
//
|
| 515 |
+
// Testbed object
|
| 516 |
+
//
|
| 517 |
+
|
| 518 |
+
TestbedConv2dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
|
| 519 |
+
|
| 520 |
+
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 521 |
+
for(auto conv_problem : problem_sizes) {
|
| 522 |
+
|
| 523 |
+
//
|
| 524 |
+
// Test
|
| 525 |
+
//
|
| 526 |
+
|
| 527 |
+
// test mode = xcross
|
| 528 |
+
passed = testbed.run(
|
| 529 |
+
conv_problem,
|
| 530 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 531 |
+
|
| 532 |
+
if (!passed) {
|
| 533 |
+
return false;
|
| 534 |
+
}
|
| 535 |
+
|
| 536 |
+
// test mode = convolution
|
| 537 |
+
passed = testbed.run(
|
| 538 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 539 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 540 |
+
|
| 541 |
+
if (!passed) {
|
| 542 |
+
return false;
|
| 543 |
+
}
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
return true;
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 550 |
+
// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
|
| 551 |
+
// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
|
| 552 |
+
// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
| 553 |
+
// (conv_blacklist_sizes)
|
| 554 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 555 |
+
/// Sweeps the default, resnet50 (and optionally rigorous) problem-size sets
/// plus any caller-supplied sizes, skipping blacklisted and duplicate
/// problems, then optionally sweeps split-k configurations. Returns false on
/// the first failure.
template <typename ImplicitGemm,
          typename ReferenceOp = Conv2dWithBroadcastReferenceOp<ImplicitGemm>,
          bool AddBroadcastFirst = false,
          bool TestSplitK = true
         >
bool TestAllConv2dWithBroadcast(
  const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(),
  const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) {

  bool passed = true;

  //
  // Testbed object
  //

  TestbedConv2dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;

  //
  // Get conv problem sizes to run conv operator
  //
  // Minimum channel count is chosen so one 128-bit access covers a whole
  // vector of ElementA.
  TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);

  // Vector of conv2d problem sizes to avoid duplicate runs
  Conv2dProblemVector conv_tested_sizes;

  Conv2dProblemVector const *problem_vectors[] = {
    &conv_test_sizes,                               // run user specified sizes
    &conv_problems.conv2d_default_sizes,            // run default and cudnn bug sizes
    &conv_problems.conv2d_resnet50_sizes,           // run resnet50 sizes
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
    &conv_problems.conv2d_rigorous_sizes,           // run large and rigorous sizes if enabled
#endif
  };

  // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (Conv2dProblemVector const * problem_vector : problem_vectors) {

    // Run conv testbed on default convolution sizes
    for(auto conv_problem : *problem_vector) {

      // Skip blacklist and avoid duplicate problem sizes
      if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
          std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
        continue;
      }

      //
      // Procedurally disable certain cases
      //

      // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
      if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
           ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
           cutlass::conv::StrideSupport::kUnity)) {
        if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }

#if 0 // relax restrictions on analytic strided dgrad
      // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
      if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
           ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
           cutlass::conv::StrideSupport::kStrided)) {
        if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }
#endif

      //
      // Test
      //
      // push back tested problem size to avoid re-running duplicates
      conv_tested_sizes.push_back(conv_problem);

      // test mode = xcross
      passed = testbed.run(
        conv_problem,
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }

      // test mode = convolution
      passed = testbed.run(
        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }
    }
  }

  // CUTLASS DGRAD's *strided* specialization does not support split-k mode:
  // run a single strided problem with non-unity alpha/beta and return early.
  if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
       ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
      (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
       cutlass::conv::StrideSupport::kStrided)) {

    passed = testbed.run(
      cutlass::conv::Conv2dProblemSize(
        {1, 56, 56, 8},   // input size  (NHWC)
        {8, 1, 1, 8},     // filter size (KRSC)
        {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
        {2, 2},           // stride (stride_h, stride_w)
        {1, 1}),          // dilation (dilation_h, dilation_w)
      cutlass::conv::SplitKMode::kSerial,
      cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
      cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));

    if (!passed) {
      return false;
    }

    return passed;
  }

  if (!TestSplitK)
    return passed;

  // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for
  // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
  // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep
  // alpha and beta for local testing, but only runs one value for alpha and beta.
  cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
    {1, 17, 11, 288},    // input size  (NHWC)
    {160, 3, 3, 288},    // filter size (KRSC)
    {1, 1, 1, 1},        // padding (pad_h, _, pad_w, _)
    {1, 1},              // stride (stride_h, stride_w)
    {1, 1}               // dilation (dilation_h, dilation_w)
  );

  cutlass::conv::SplitKMode split_k_modes [] = {
    cutlass::conv::SplitKMode::kSerial
  };

  // 201 intentionally exceeds the k-dimension tile count to exercise the
  // clamping/edge path of split-k.
  int split_k_slices[] = {
    1, 2, 3, 4, 201
  };

  double problem_alpha[] = {
    2.0
  };

  double problem_beta[] = {
    2.0
  };

  for (auto split_k_mode : split_k_modes) {
    for (auto split_k_slice : split_k_slices) {
      for (auto alpha : problem_alpha) {
        for (auto beta : problem_beta) {

          passed = testbed.run(
            conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
            split_k_mode,
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));

          if (!passed) {
            return false;
          }
        }
      }
    }
  }

  return passed;
}
|
| 729 |
+
|
| 730 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 731 |
+
|
| 732 |
+
} // namespace device
|
| 733 |
+
} // namespace conv
|
| 734 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h
ADDED
|
@@ -0,0 +1,643 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <fstream>
|
| 37 |
+
|
| 38 |
+
#include "../../common/cutlass_unit_test.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
|
| 41 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 42 |
+
#include "cutlass/reduction/device/tensor_reduce.h"
|
| 43 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 44 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 45 |
+
|
| 46 |
+
#include "conv2d_problems.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/util/host_tensor.h"
|
| 49 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 50 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 51 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 52 |
+
|
| 53 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 54 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 55 |
+
|
| 56 |
+
#include "cutlass/core_io.h"
|
| 57 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 58 |
+
|
| 59 |
+
#include "../cache_testbed_output.h"
|
| 60 |
+
|
| 61 |
+
namespace test {
|
| 62 |
+
namespace conv {
|
| 63 |
+
namespace device {
|
| 64 |
+
|
| 65 |
+
/// Testbed for ImplicitGemmConvolution kernels whose epilogue also produces a
/// per-threadblock partial reduction and an auxiliary elementwise tensor output.
/// run() launches the device kernel, completes the partial reduction with a
/// device-side TensorReduction, and checks both the convolution output and the
/// final reduction against reference implementations.
template <typename Conv2d>
class TestbedConv2dWithReduction {
public:

  using ElementA = typename Conv2d::ElementA;
  using LayoutA = typename Conv2d::LayoutA;
  using ElementB = typename Conv2d::ElementB;
  using LayoutB = typename Conv2d::LayoutB;
  using ElementC = typename Conv2d::ElementC;
  using LayoutC = typename Conv2d::LayoutC;
  using ElementAccumulator = typename Conv2d::ElementAccumulator;
  using ElementCompute = typename Conv2d::ElementCompute;
  using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
  // Element type of the auxiliary tensor emitted by the epilogue.
  using ElementT = typename EpilogueOutputOp::ElementTensor;

  // Fprop / Dgrad / Wgrad — fixed by the kernel type under test.
  static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;

public:

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  uint64_t seed;

  cutlass::HostTensor<ElementA, LayoutA> tensor_A;
  cutlass::HostTensor<ElementB, LayoutB> tensor_B;
  cutlass::HostTensor<ElementC, LayoutC> tensor_C;

  // Partial reductions written by the epilogue: one row per threadblock tile
  // along the NPQ dimension, K columns (see initialize() for the extent).
  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Reduction;
  // Auxiliary elementwise tensor output of the epilogue (NPQ x K).
  cutlass::HostTensor<ElementT, cutlass::layout::RowMajor> tensor_Tensor;
  // Result of collapsing tensor_Reduction to a single value per output channel K.
  cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Final_Reduction;

  cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;

public:

  /// Constructs the testbed; records fill distributions and the RNG seed used
  /// by initialize().
  TestbedConv2dWithReduction(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {

  }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  void initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      // Small integer range [-2, 2] keeps accumulation exact for verification.
      int scope = 2;

      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, scope, -scope, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
    }
    else {
      // Other distribution kinds: leave the tensor contents untouched.
    }
  }

  /// Resizes all tensors for the given problem, fills operands on the host,
  /// and copies them to the device.
  void initialize(
    cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {

    tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
    tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
    tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));

    // One partial-reduction row per threadblock tile along NPQ:
    // ceil(N*P*Q / ThreadblockShape::kM) rows by K columns.
    tensor_Reduction.resize({
      1,
      1,
      (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM,
      (problem_size.K)
    });

    // Final reduction: one value per output channel K.
    tensor_Final_Reduction.resize({
      1,
      1,
      1,
      (problem_size.K)
    });

    tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K});

    tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));

    // Distinct seed multipliers decorrelate the A, B, and C fills.
    initialize_tensor(tensor_A.host_view(), init_A, seed);
    initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
    initialize_tensor(tensor_C.host_view(), init_C, seed * 39);

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_D_computed.sync_device();
    tensor_D_reference.sync_device();
  }

  /// Returns true if the active CUDA device can provide the shared memory the
  /// kernel requires; false means the test should be waived.
  bool sufficient() const {
    //
    // Determine SMEM requirements and waive if not satisfied
    //

    size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);

    cudaDeviceProp properties;
    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDevice() API call failed.");
    }

    result = cudaGetDeviceProperties(&properties, device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDeviceProperties() failed");
    }

    // sharedMemPerBlockOptin is the opt-in maximum, not the default limit.
    if (properties.sharedMemPerBlockOptin < smem_size) {
      return false;
    }

    return true;
  }

  /// Executes one test: launches the kernel, finishes the reduction, and
  /// verifies both outputs. Returns true on pass (or waive), false on mismatch.
  bool run(
    cutlass::conv::Conv2dProblemSize const &problem_size,
    cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
    ElementCompute alpha = ElementCompute(1),
    ElementCompute beta = ElementCompute(0)) {

    // Waive test if insufficient CUDA device
    if (!sufficient()) {
      if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
        std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
      }
      return true;
    }

#if 0 //display conv2d problem size for debugging
    std::cout << problem_size << std::endl
              << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
              << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
              << std::endl;
#endif

    initialize(problem_size);

    // configure the operator
    Conv2d conv2d_op;

    typename Conv2d::Arguments conv2d_args(
      problem_size,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C.device_ref(),
      tensor_D_computed.device_ref(),
      {alpha, beta},
      split_k_mode,
      tensor_Reduction.device_data(),
      tensor_Tensor.device_data(),
      static_cast<int>(tensor_Reduction.stride()[0]),
      static_cast<int>(tensor_Tensor.stride()[0])
    );

    // find workspace requirement for parallel split-k reduction
    size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);

    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());

    if (status != cutlass::Status::kSuccess) {
      // Unsupported configuration (not a failure): report and waive.
      cudaError_t error = cudaGetLastError();
      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
      return true;
    }

    // conv2d operation with parallel split-k-mode
    if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {

      // conv2d output is written to workspace in global memory
      conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
      // accumulate mma for each cta in k-dimension (1.0 * A * B)
      conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
      // update conv2d operator arguments
      status = conv2d_op.update(conv2d_args, workspace.get());
    }

    EXPECT_TRUE(status == cutlass::Status::kSuccess);
    if (status != cutlass::Status::kSuccess) {
      return false;
    }

    // run conv2d operator
    status = conv2d_op();

    EXPECT_TRUE(status == cutlass::Status::kSuccess);
    if (status != cutlass::Status::kSuccess) {
      return false;
    }

    bool passed = false;

    cudaError_t result = cudaDeviceSynchronize();
    EXPECT_EQ(result, cudaSuccess) << " device reference error: "
                                   << cudaGetErrorString(result);

    // Final reduction over the partial reduction tensor
    using Functor = cutlass::plus<ElementAccumulator>;
    using TensorReduction = cutlass::reduction::device::TensorReduction<
      ElementAccumulator,
      ElementAccumulator,
      LayoutC,
      Functor,
      8,
      ElementAccumulator
    >;

    // Reduce along extent index 2 — the per-threadblock partial dimension.
    TensorReduction reduction(tensor_Reduction.extent(), 2);

    cutlass::DeviceAllocation<uint8_t> reduction_device_workspace(reduction.workspace_size());

    status = reduction.reduce(
      tensor_Final_Reduction.device_ref(),
      tensor_Reduction.device_ref(),
      reduction_device_workspace.get(),
      ElementAccumulator());

    EXPECT_EQ(status, cutlass::Status::kSuccess);
    EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);

    //
    // Reference check
    //

    tensor_D_computed.sync_host();

#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED

    cutlass::reference::device::Conv2d<
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementC,
      LayoutC,
      ElementCompute,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C.device_ref(),
      tensor_D_reference.device_ref(),
      alpha,
      beta);

    // sync host (copy device data to host) for dumping error output in case of mismatches
    tensor_D_reference.sync_host();

#else

    cutlass::reference::host::Conv2d<
      ElementA,
      LayoutA,
      ElementB,
      LayoutB,
      ElementC,
      LayoutC,
      ElementCompute,
      ElementAccumulator
    >(
      kConvolutionalOperator,
      problem_size,
      tensor_A.host_ref(),
      tensor_B.host_ref(),
      tensor_C.host_ref(),
      tensor_D_reference.host_ref(),
      alpha,
      beta);

#endif

    passed = cutlass::reference::host::TensorEquals(
      tensor_D_computed.host_view(),
      tensor_D_reference.host_view());

    EXPECT_TRUE(passed);

    //
    // Reference check on reduction results
    //

    tensor_Reduction.sync_host();
    tensor_Final_Reduction.sync_host();

    // compute backwards for reduction results
    cutlass::HostTensor<ElementAccumulator, LayoutC> reference_Reduction;
    reference_Reduction.resize({
      1,
      1,
      1,
      (problem_size.K)
    });

    // Host reference: sum the reference output D over N, P, Q for each K.
    for (int k = 0; k < problem_size.K; ++k) {
      ElementAccumulator reduced_value = ElementAccumulator();
      for (int n = 0; n < problem_size.N; ++n) {
        for (int p = 0; p < problem_size.P; ++p) {
          for (int q = 0; q < problem_size.Q; ++q) {
            reduced_value += tensor_D_reference.at({n, p, q, k});
          }
        }
      }
      reference_Reduction.at({0, 0, 0, k}) = reduced_value;
    }

    passed = cutlass::reference::host::TensorEquals(
      tensor_Final_Reduction.host_view(),
      reference_Reduction.host_view()
    );

    EXPECT_TRUE(passed);

    // On failure, dump all operands and results to a descriptively named file.
    if (!passed) {
      std::stringstream fname;

      fname << "error_Conv2d_ImplicitGemm_device_"
            << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
            << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
               (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
            << "nhwc_"
            << problem_size.N << "x"
            << problem_size.H << "x"
            << problem_size.W << "x"
            << problem_size.C
            << "_krsc_"
            << problem_size.K << "x"
            << problem_size.R << "x"
            << problem_size.S << "x"
            << problem_size.C
            << "_padding_"
            << problem_size.pad_h << "x"
            << problem_size.pad_w
            << "_stride_"
            << problem_size.stride_h << "x"
            << problem_size.stride_w
            << "_dilation_"
            << problem_size.dilation_h << "x"
            << problem_size.dilation_w << "_"
            << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
            << Conv2d::ThreadblockShape::kM << "x"
            << Conv2d::ThreadblockShape::kN << "x"
            << Conv2d::ThreadblockShape::kK << "_"
            << Conv2d::WarpShape::kM << "x"
            << Conv2d::WarpShape::kN << "x"
            << Conv2d::WarpShape::kK << ".txt";

      std::cout << fname.str() << std::endl;

      std::ofstream results(fname.str());

      results << problem_size << std::endl;

      results
        << "\nA:\n" << tensor_A.host_view() << "\n"
        << "\nB:\n" << tensor_B.host_view() << "\n"
        << "\nC:\n" << tensor_C.host_view() << "\n"
        << "\nD reference:\n" << tensor_D_reference.host_view() << "\n"
        << "\nD computed:\n" << tensor_D_computed.host_view() << "\n"
        << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n"
        << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n";
    }

    return passed;
  }
};
|
| 463 |
+
|
| 464 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 465 |
+
// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
|
| 466 |
+
// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
|
| 467 |
+
// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
| 468 |
+
// (conv_blacklist_sizes)
|
| 469 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 470 |
+
/// Sweeps the default, ResNet-50, and user-supplied Conv2d problem sizes (plus
/// a split-k sweep) through TestbedConv2dWithReduction. Returns false on the
/// first failing problem.
template <typename ImplicitGemm>
bool TestAllConv2dWithReduction(
  const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
  const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {

  bool passed = true;

  //
  // Testbed object
  //

  TestbedConv2dWithReduction<ImplicitGemm> testbed;

  //
  // Get conv problem sizes to run conv operator
  //
  // Minimum channel count is one 128-bit access worth of ElementA.
  TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);

  // Vector of conv2d problem sizes to avoid duplicate runs
  Conv2dProblemVector conv_tested_sizes;

  Conv2dProblemVector const *problem_vectors[] = {
    &conv_test_sizes,                               // run user specified sizes
    &conv_problems.conv2d_default_sizes,            // run default and cudnn bug sizes
    &conv_problems.conv2d_resnet50_sizes,           // run resnet50 sizes
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
    &conv_problems.conv2d_rigorous_sizes,           // run large and rigorous sizes if enabled
#endif
  };

  // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (Conv2dProblemVector const * problem_vector : problem_vectors) {

    // Run conv testbed on default convolution sizes
    for(auto conv_problem : *problem_vector) {

      // Skip blacklist and avoid duplicate problem sizes
      if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
          std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
        continue;
      }

      //
      // Procedurally disable certain cases
      //

      // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
      if ((ImplicitGemm::kConvolutionalOperator ==
            cutlass::conv::Operator::kDgrad) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
            cutlass::conv::StrideSupport::kUnity)) {
        if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }

#if 0 // relax restrictions on analytic strided dgrad
      // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
      if ((ImplicitGemm::kConvolutionalOperator ==
            cutlass::conv::Operator::kDgrad) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
            cutlass::conv::StrideSupport::kStrided)) {
        if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }
#endif

      //
      // Test
      //
      // push back tested problem size to avoid re-running duplicates
      conv_tested_sizes.push_back(conv_problem);

      // test mode = xcross
      passed = testbed.run(
        conv_problem,
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }

      // test mode = convolution
      passed = testbed.run(
        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }
    }
  }

  // CUTLASS DGRAD's *strided* specialization does not support split-k mode:
  // run a single strided problem and return early instead of sweeping split-k.
  if ((ImplicitGemm::kConvolutionalOperator ==
        cutlass::conv::Operator::kDgrad) &&
      (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
        cutlass::conv::StrideSupport::kStrided)) {

    passed = testbed.run(
      cutlass::conv::Conv2dProblemSize(
        {1, 56, 56, 8},   // input size  (NHWC)
        {8, 1, 1, 8},     // filter size (KRSC)
        {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
        {2, 2},           // stride (stride_h, stride_w)
        {1, 1}),          // dilation (dilation_h, dilation_w)
      cutlass::conv::SplitKMode::kSerial,
      cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
      cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));

    if (!passed) {
      return false;
    }

    return passed;
  }

  // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for
  // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
  // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep
  // alpha and beta for local testing, but only runs one value for alpha and beta.
  cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
    {1, 17, 11, 288},   // input size  (NHWC)
    {160, 3, 3, 288},   // filter size (KRSC)
    {1, 1, 1, 1},       // padding (pad_h, _, pad_w, _)
    {1, 1},             // stride (stride_h, stride_w)
    {1, 1}              // dilation (dilation_h, dilation_w)
  );

  // Parallel SplitK is not tested.
  cutlass::conv::SplitKMode split_k_modes [] = {
    cutlass::conv::SplitKMode::kSerial,
  };

  // 201 deliberately exceeds a reasonable slice count to exercise clamping paths.
  int split_k_slices[] = {
    1, 2, 3, 4, 201
  };

  double problem_alpha[] = {
    2.0
  };

  double problem_beta[] = {
    2.0
  };

  for (auto split_k_mode : split_k_modes) {
    for (auto split_k_slice : split_k_slices) {
      for (auto alpha : problem_alpha) {
        for (auto beta : problem_beta) {

          passed = testbed.run(
            conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
            split_k_mode,
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));

          if (!passed) {
            return false;
          }
        }
      }
    }
  }

  return passed;
}
|
| 638 |
+
|
| 639 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 640 |
+
|
| 641 |
+
} // namespace device
|
| 642 |
+
} // namespace conv
|
| 643 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed sizes for Conv2d problem
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "../../common/cutlass_unit_test.h"
|
| 37 |
+
|
| 38 |
+
#include "cutlass/cutlass.h"
|
| 39 |
+
|
| 40 |
+
#include "cutlass/aligned_buffer.h"
|
| 41 |
+
#include "cutlass/numeric_types.h"
|
| 42 |
+
#include "cutlass/layout/matrix.h"
|
| 43 |
+
#include "cutlass/layout/tensor.h"
|
| 44 |
+
#include "cutlass/layout/pitch_linear.h"
|
| 45 |
+
#include "cutlass/core_io.h"
|
| 46 |
+
#include "cutlass/util/host_tensor.h"
|
| 47 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 48 |
+
#include "cutlass/conv/convolution.h"
|
| 49 |
+
#include "cutlass/conv/conv2d_problem_size.h"
|
| 50 |
+
#include "cutlass/conv/conv3d_problem_size.h"
|
| 51 |
+
|
| 52 |
+
namespace test {
|
| 53 |
+
namespace conv {
|
| 54 |
+
namespace device {
|
| 55 |
+
|
| 56 |
+
// Ordered list of Conv3d problem sizes; used by the 3-D conv testbeds below.
using Conv3dProblemVector = std::vector<cutlass::conv::Conv3dProblemSize>;
|
| 57 |
+
|
| 58 |
+
////////////////////////////////////////////////////////////////////////////
|
| 59 |
+
/// Structure TestbedConv3dProblemSizes initializes and holds conv default and
|
| 60 |
+
/// important network sizes
|
| 61 |
+
////////////////////////////////////////////////////////////////////////////
|
| 62 |
+
struct TestbedConv3dProblemSizes {
|
| 63 |
+
|
| 64 |
+
//
|
| 65 |
+
// Data members
|
| 66 |
+
//
|
| 67 |
+
int minimum_channel_size;
|
| 68 |
+
Conv3dProblemVector conv3d_default_sizes;
|
| 69 |
+
Conv3dProblemVector conv3d_vnet_medical_sizes;
|
| 70 |
+
|
| 71 |
+
//
|
| 72 |
+
// Methods
|
| 73 |
+
//
|
| 74 |
+
/// Default ctor
|
| 75 |
+
TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) {
|
| 76 |
+
|
| 77 |
+
initialize_conv3d_default_sizes();
|
| 78 |
+
initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/);
|
| 79 |
+
|
| 80 |
+
filter_all();
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
/// Eliminates some illegal cases
|
| 84 |
+
void filter_all() {
|
| 85 |
+
|
| 86 |
+
Conv3dProblemVector *problems_vectors[] = {
|
| 87 |
+
&conv3d_default_sizes,
|
| 88 |
+
&conv3d_vnet_medical_sizes
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
for (Conv3dProblemVector *problems : problems_vectors) {
|
| 92 |
+
Conv3dProblemVector filtered;
|
| 93 |
+
|
| 94 |
+
for (cutlass::conv::Conv3dProblemSize const & problem : *problems) {
|
| 95 |
+
if (!(problem.C % minimum_channel_size)) {
|
| 96 |
+
filtered.push_back(problem);
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
*problems = filtered;
|
| 101 |
+
}
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
// Add a few standard convolution problem sizes
|
| 105 |
+
void initialize_conv3d_default_sizes() {
|
| 106 |
+
|
| 107 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 108 |
+
{1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC)
|
| 109 |
+
{8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC)
|
| 110 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 111 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 112 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 113 |
+
));
|
| 114 |
+
|
| 115 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 116 |
+
{1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC)
|
| 117 |
+
{8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC)
|
| 118 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 119 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 120 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 121 |
+
));
|
| 122 |
+
|
| 123 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 124 |
+
{1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC)
|
| 125 |
+
{8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC)
|
| 126 |
+
CUTLASS_STL_NAMESPACE::make_tuple(
|
| 127 |
+
cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w)
|
| 128 |
+
cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w)
|
| 129 |
+
),
|
| 130 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 131 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 132 |
+
));
|
| 133 |
+
|
| 134 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 135 |
+
{1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC)
|
| 136 |
+
{8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC)
|
| 137 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 138 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 139 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 140 |
+
));
|
| 141 |
+
|
| 142 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 143 |
+
{1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC)
|
| 144 |
+
{8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC)
|
| 145 |
+
CUTLASS_STL_NAMESPACE::make_tuple(
|
| 146 |
+
cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w)
|
| 147 |
+
cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w)
|
| 148 |
+
),
|
| 149 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 150 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 151 |
+
));
|
| 152 |
+
|
| 153 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 154 |
+
{1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC)
|
| 155 |
+
{8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC)
|
| 156 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 157 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 158 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 159 |
+
));
|
| 160 |
+
|
| 161 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 162 |
+
{1, 1, 15, 19, 160}, // input size (NDHWC)
|
| 163 |
+
{224, 1, 3, 6, 160}, // filter size (KTRSC)
|
| 164 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 165 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 166 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 167 |
+
));
|
| 168 |
+
|
| 169 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 170 |
+
{1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC)
|
| 171 |
+
{8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC)
|
| 172 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 173 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 174 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 175 |
+
));
|
| 176 |
+
|
| 177 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 178 |
+
{1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC)
|
| 179 |
+
{16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC)
|
| 180 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 181 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 182 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 183 |
+
));
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
|
| 187 |
+
{1, 11, 15, 19, 64}, // input size (NDHWC)
|
| 188 |
+
{32, 4, 3, 6, 64}, // filter size (KTRSC)
|
| 189 |
+
cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w)
|
| 190 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 191 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 192 |
+
));
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
// Add vnet layers to unit testing sizes
|
| 196 |
+
void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) {
|
| 197 |
+
|
| 198 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 199 |
+
{batch_size, 32, 32, 32, 16}, // input size (NDHWC)
|
| 200 |
+
{32, 2, 2, 2, 16}, // filter size (KTRSC)
|
| 201 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 202 |
+
cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w)
|
| 203 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 204 |
+
));
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 208 |
+
{batch_size, 16, 16, 16, 32}, // input size (NDHWC)
|
| 209 |
+
{32, 3, 3, 3, 32}, // filter size (KTRSC)
|
| 210 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 211 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 212 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 213 |
+
));
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 217 |
+
{batch_size, 16, 16, 16, 32}, // input size (NDHWC)
|
| 218 |
+
{64, 2, 2, 2, 32}, // filter size (KTRSC)
|
| 219 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 220 |
+
cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w)
|
| 221 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 222 |
+
));
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 226 |
+
{batch_size, 8, 8, 8, 64}, // input size (NDHWC)
|
| 227 |
+
{64, 3, 3, 3, 64}, // filter size (KTRSC)
|
| 228 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 229 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 230 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 231 |
+
));
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 235 |
+
{batch_size, 8, 8, 8, 64}, // input size (NDHWC)
|
| 236 |
+
{128, 2, 2, 2, 64}, // filter size (KTRSC)
|
| 237 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 238 |
+
cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w)
|
| 239 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 240 |
+
));
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 244 |
+
{batch_size, 4, 4, 4, 128}, // input size (NDHWC)
|
| 245 |
+
{128, 3, 3, 3, 128}, // filter size (KTRSC)
|
| 246 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 247 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 248 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 249 |
+
));
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 253 |
+
{batch_size, 8, 8, 8, 128}, // input size (NDHWC)
|
| 254 |
+
{128, 3, 3, 3, 128}, // filter size (KTRSC)
|
| 255 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 256 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 257 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 258 |
+
));
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 262 |
+
{batch_size, 16, 16, 16, 64}, // input size (NDHWC)
|
| 263 |
+
{64, 3, 3, 3, 64}, // filter size (KTRSC)
|
| 264 |
+
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w)
|
| 265 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 266 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 267 |
+
));
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 271 |
+
{batch_size, 32, 32, 32, 16}, // input size (NDHWC)
|
| 272 |
+
{64, 2, 2, 2, 16}, // filter size (KTRSC)
|
| 273 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 274 |
+
cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w)
|
| 275 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 276 |
+
));
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
|
| 280 |
+
{batch_size, 16, 16, 16, 32}, // input size (NDHWC)
|
| 281 |
+
{128, 2, 2, 2, 32}, // filter size (KTRSC)
|
| 282 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 283 |
+
cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w)
|
| 284 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 285 |
+
));
|
| 286 |
+
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
};
|
| 290 |
+
|
| 291 |
+
} // namespace device
|
| 292 |
+
} // namespace conv
|
| 293 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h
ADDED
|
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <fstream>
|
| 37 |
+
|
| 38 |
+
#include "../../common/cutlass_unit_test.h"
|
| 39 |
+
#include "cutlass/cutlass.h"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 43 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 44 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 45 |
+
|
| 46 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 47 |
+
|
| 48 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 49 |
+
|
| 50 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 51 |
+
|
| 52 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 53 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 54 |
+
|
| 55 |
+
#include "conv3d_problems.h"
|
| 56 |
+
#include "cutlass/core_io.h"
|
| 57 |
+
|
| 58 |
+
#include "../cache_testbed_output.h"
|
| 59 |
+
|
| 60 |
+
namespace test {
|
| 61 |
+
namespace conv {
|
| 62 |
+
namespace device {
|
| 63 |
+
|
| 64 |
+
template <typename Conv3d>
|
| 65 |
+
class TestbedConv3d {
|
| 66 |
+
public:
|
| 67 |
+
|
| 68 |
+
using ElementA = typename Conv3d::ElementA;
|
| 69 |
+
using LayoutA = typename Conv3d::LayoutA;
|
| 70 |
+
using ElementB = typename Conv3d::ElementB;
|
| 71 |
+
using LayoutB = typename Conv3d::LayoutB;
|
| 72 |
+
using ElementC = typename Conv3d::ElementC;
|
| 73 |
+
using LayoutC = typename Conv3d::LayoutC;
|
| 74 |
+
using ElementAccumulator = typename Conv3d::ElementAccumulator;
|
| 75 |
+
using ElementCompute = typename Conv3d::ElementCompute;
|
| 76 |
+
using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp;
|
| 77 |
+
|
| 78 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator;
|
| 79 |
+
|
| 80 |
+
/// Reduction kernel
|
| 81 |
+
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
|
| 82 |
+
ElementAccumulator,
|
| 83 |
+
typename EpilogueOutputOp::ElementAccumulator,
|
| 84 |
+
EpilogueOutputOp::kCount
|
| 85 |
+
>;
|
| 86 |
+
|
| 87 |
+
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
|
| 88 |
+
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
|
| 89 |
+
EpilogueOutputOp,
|
| 90 |
+
ReductionOp
|
| 91 |
+
>;
|
| 92 |
+
|
| 93 |
+
using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
|
| 94 |
+
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
|
| 95 |
+
|
| 96 |
+
public:
|
| 97 |
+
|
| 98 |
+
/// Initialization
|
| 99 |
+
cutlass::Distribution::Kind init_A;
|
| 100 |
+
cutlass::Distribution::Kind init_B;
|
| 101 |
+
cutlass::Distribution::Kind init_C;
|
| 102 |
+
uint64_t seed;
|
| 103 |
+
|
| 104 |
+
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
| 105 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
|
| 106 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
| 107 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
|
| 108 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
|
| 109 |
+
|
| 110 |
+
public:
|
| 111 |
+
|
| 112 |
+
TestbedConv3d(
|
| 113 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 114 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 115 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 116 |
+
uint64_t seed_ = 2080
|
| 117 |
+
):
|
| 118 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
| 119 |
+
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
/// Helper to initialize a tensor view
|
| 123 |
+
template <typename Element, typename Layout>
|
| 124 |
+
void initialize_tensor(
|
| 125 |
+
cutlass::TensorView<Element, Layout> view,
|
| 126 |
+
cutlass::Distribution::Kind dist_kind,
|
| 127 |
+
uint64_t seed) {
|
| 128 |
+
|
| 129 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 130 |
+
|
| 131 |
+
int scope;
|
| 132 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 133 |
+
|
| 134 |
+
if (bits <= 8) {
|
| 135 |
+
scope = 2;
|
| 136 |
+
}
|
| 137 |
+
else if (bits == 16) {
|
| 138 |
+
scope = 4;
|
| 139 |
+
}
|
| 140 |
+
else {
|
| 141 |
+
scope = 8;
|
| 142 |
+
}
|
| 143 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 144 |
+
view, seed, scope, -scope, 0);
|
| 145 |
+
}
|
| 146 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 147 |
+
|
| 148 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 149 |
+
}
|
| 150 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 151 |
+
|
| 152 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 153 |
+
}
|
| 154 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 155 |
+
|
| 156 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 157 |
+
}
|
| 158 |
+
else {
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
void initialize(
|
| 163 |
+
cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) {
|
| 164 |
+
|
| 165 |
+
tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
|
| 166 |
+
tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
|
| 167 |
+
tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 168 |
+
tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 169 |
+
tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 170 |
+
|
| 171 |
+
initialize_tensor(tensor_A.host_view(), init_A, seed);
|
| 172 |
+
initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
|
| 173 |
+
initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
|
| 174 |
+
|
| 175 |
+
tensor_A.sync_device();
|
| 176 |
+
tensor_B.sync_device();
|
| 177 |
+
tensor_C.sync_device();
|
| 178 |
+
tensor_D_computed.sync_device();
|
| 179 |
+
tensor_D_reference.sync_device();
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
bool sufficient() const {
|
| 183 |
+
//
|
| 184 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 185 |
+
//
|
| 186 |
+
|
| 187 |
+
size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage);
|
| 188 |
+
|
| 189 |
+
cudaDeviceProp properties;
|
| 190 |
+
int device_idx;
|
| 191 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 192 |
+
|
| 193 |
+
if (result != cudaSuccess) {
|
| 194 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
result = cudaGetDeviceProperties(&properties, device_idx);
|
| 198 |
+
|
| 199 |
+
if (result != cudaSuccess) {
|
| 200 |
+
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
if (properties.sharedMemPerBlockOptin < smem_size) {
|
| 204 |
+
return false;
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
return true;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
/// Executes one test
|
| 212 |
+
bool run(
|
| 213 |
+
cutlass::conv::Conv3dProblemSize const &problem_size,
|
| 214 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 215 |
+
ElementCompute alpha = ElementCompute(1),
|
| 216 |
+
ElementCompute beta = ElementCompute()) {
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
// Waive test if insufficient CUDA device
|
| 220 |
+
if (!sufficient()) {
|
| 221 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 222 |
+
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
|
| 223 |
+
}
|
| 224 |
+
return true;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
#if 0 //display conv2d problem size for debugging
|
| 228 |
+
std::cout << problem_size << std::endl
|
| 229 |
+
<< "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl
|
| 230 |
+
<< "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
|
| 231 |
+
<< std::endl;
|
| 232 |
+
#endif
|
| 233 |
+
|
| 234 |
+
initialize(problem_size);
|
| 235 |
+
|
| 236 |
+
// configure the operator
|
| 237 |
+
Conv3d conv3d_op;
|
| 238 |
+
|
| 239 |
+
typename Conv3d::Arguments conv3d_args(
|
| 240 |
+
problem_size,
|
| 241 |
+
tensor_A.device_ref(),
|
| 242 |
+
tensor_B.device_ref(),
|
| 243 |
+
tensor_C.device_ref(),
|
| 244 |
+
tensor_D_computed.device_ref(),
|
| 245 |
+
{alpha, beta},
|
| 246 |
+
split_k_mode
|
| 247 |
+
);
|
| 248 |
+
|
| 249 |
+
cutlass::Status status = conv3d_op.can_implement(conv3d_args);
|
| 250 |
+
if (status != cutlass::Status::kSuccess) {
|
| 251 |
+
std::cerr << "can_implement failed for the given problem_size: \n";
|
| 252 |
+
return false;
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
// find workspace requirement for parallel split-k reduction
|
| 256 |
+
size_t workspace_size = Conv3d::get_workspace_size(conv3d_args);
|
| 257 |
+
|
| 258 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 259 |
+
|
| 260 |
+
status = conv3d_op.initialize(conv3d_args, workspace.get());
|
| 261 |
+
|
| 262 |
+
if (status != cutlass::Status::kSuccess) {
|
| 263 |
+
cudaError_t error = cudaGetLastError();
|
| 264 |
+
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
|
| 265 |
+
return true;
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
// conv3d operation with parallel split-k-mode
|
| 269 |
+
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
| 270 |
+
|
| 271 |
+
// conv3d output is written to workspace in global memory
|
| 272 |
+
conv3d_args.ref_D.reset(reinterpret_cast<ElementAccumulator*>(workspace.get()));
|
| 273 |
+
// accumulate mma for each cta in k-dimension (1.0 * A * B)
|
| 274 |
+
conv3d_args.output_op = {1.0, 0.0};
|
| 275 |
+
// update conv3d operator arguments
|
| 276 |
+
status = conv3d_op.update(conv3d_args, workspace.get());
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 280 |
+
if (status != cutlass::Status::kSuccess) {
|
| 281 |
+
return false;
|
| 282 |
+
}
|
| 283 |
+
|
| 284 |
+
// run conv3d operator
|
| 285 |
+
status = conv3d_op();
|
| 286 |
+
|
| 287 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 288 |
+
if (status != cutlass::Status::kSuccess) {
|
| 289 |
+
return false;
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
|
| 293 |
+
|
| 294 |
+
// configure parallel reduction operator
|
| 295 |
+
ReductionDevice reduction_op;
|
| 296 |
+
|
| 297 |
+
typename ReductionDevice::Arguments reduction_args(
|
| 298 |
+
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
|
| 299 |
+
problem_size.split_k_slices,
|
| 300 |
+
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
|
| 301 |
+
{
|
| 302 |
+
reinterpret_cast<ElementAccumulator*> (workspace.get()),
|
| 303 |
+
ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
tensor_D_computed.device_data(),
|
| 307 |
+
ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
tensor_C.device_data(),
|
| 311 |
+
ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
|
| 312 |
+
},
|
| 313 |
+
// apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
|
| 314 |
+
{alpha, beta}
|
| 315 |
+
);
|
| 316 |
+
|
| 317 |
+
status = reduction_op.initialize(reduction_args, nullptr);
|
| 318 |
+
|
| 319 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 320 |
+
if (status != cutlass::Status::kSuccess) {
|
| 321 |
+
return false;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
// run prallel reduction kernel
|
| 325 |
+
status = reduction_op();
|
| 326 |
+
|
| 327 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 328 |
+
if (status != cutlass::Status::kSuccess) {
|
| 329 |
+
return false;
|
| 330 |
+
}
|
| 331 |
+
}
|
| 332 |
+
bool passed = false;
|
| 333 |
+
|
| 334 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 335 |
+
EXPECT_EQ(result, cudaSuccess) << " device reference error: "
|
| 336 |
+
<< cudaGetErrorString(result);
|
| 337 |
+
|
| 338 |
+
tensor_D_computed.sync_host();
|
| 339 |
+
|
| 340 |
+
//
|
| 341 |
+
// Reference check - support caching results
|
| 342 |
+
//
|
| 343 |
+
|
| 344 |
+
CachedTestKey cached_test_key = CreateCachedConv3dTestKey<
|
| 345 |
+
ElementA, LayoutA,
|
| 346 |
+
ElementB, LayoutB,
|
| 347 |
+
ElementC, LayoutC,
|
| 348 |
+
ElementAccumulator,
|
| 349 |
+
ElementCompute
|
| 350 |
+
>(
|
| 351 |
+
kConvolutionalOperator,
|
| 352 |
+
problem_size,
|
| 353 |
+
alpha,
|
| 354 |
+
beta,
|
| 355 |
+
tensor_A.host_view(),
|
| 356 |
+
tensor_B.host_view(),
|
| 357 |
+
tensor_C.host_view()
|
| 358 |
+
);
|
| 359 |
+
|
| 360 |
+
//
|
| 361 |
+
// Look for the cached key
|
| 362 |
+
//
|
| 363 |
+
|
| 364 |
+
bool cached_result_loaded = false;
|
| 365 |
+
CachedTestResult cached_test_result;
|
| 366 |
+
|
| 367 |
+
std::string conv3d_result_cache_name =
|
| 368 |
+
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
|
| 369 |
+
|
| 370 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 371 |
+
|
| 372 |
+
CachedTestResultListing cached_results(conv3d_result_cache_name);
|
| 373 |
+
|
| 374 |
+
auto cached = cached_results.find(cached_test_key);
|
| 375 |
+
|
| 376 |
+
cached_result_loaded = cached.first;
|
| 377 |
+
if (cached_result_loaded) {
|
| 378 |
+
cached_test_result = cached.second;
|
| 379 |
+
}
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
if (!cached_result_loaded) {
|
| 383 |
+
|
| 384 |
+
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
| 385 |
+
|
| 386 |
+
cutlass::reference::device::Conv3d<
|
| 387 |
+
ElementA,
|
| 388 |
+
LayoutA,
|
| 389 |
+
ElementB,
|
| 390 |
+
LayoutB,
|
| 391 |
+
ElementC,
|
| 392 |
+
LayoutC,
|
| 393 |
+
ElementAccumulator,
|
| 394 |
+
ElementCompute
|
| 395 |
+
>(
|
| 396 |
+
kConvolutionalOperator,
|
| 397 |
+
problem_size,
|
| 398 |
+
tensor_A.device_ref(),
|
| 399 |
+
tensor_B.device_ref(),
|
| 400 |
+
tensor_C.device_ref(),
|
| 401 |
+
tensor_D_reference.device_ref(),
|
| 402 |
+
alpha,
|
| 403 |
+
beta
|
| 404 |
+
);
|
| 405 |
+
|
| 406 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 407 |
+
tensor_D_reference.sync_host();
|
| 408 |
+
|
| 409 |
+
#else
|
| 410 |
+
cutlass::reference::host::Conv3d<
|
| 411 |
+
ElementA,
|
| 412 |
+
LayoutA,
|
| 413 |
+
ElementB,
|
| 414 |
+
LayoutB,
|
| 415 |
+
ElementC,
|
| 416 |
+
LayoutC,
|
| 417 |
+
ElementAccumulator,
|
| 418 |
+
ElementCompute
|
| 419 |
+
>(
|
| 420 |
+
kConvolutionalOperator,
|
| 421 |
+
problem_size,
|
| 422 |
+
tensor_A.host_ref(),
|
| 423 |
+
tensor_B.host_ref(),
|
| 424 |
+
tensor_C.host_ref(),
|
| 425 |
+
tensor_D_reference.host_ref(),
|
| 426 |
+
alpha,
|
| 427 |
+
beta
|
| 428 |
+
);
|
| 429 |
+
#endif
|
| 430 |
+
|
| 431 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 432 |
+
|
| 433 |
+
cached_test_result.D = TensorHash(tensor_D_reference.host_view());
|
| 434 |
+
|
| 435 |
+
CachedTestResultListing cached_results(conv3d_result_cache_name);
|
| 436 |
+
|
| 437 |
+
cached_results.append(cached_test_key, cached_test_result);
|
| 438 |
+
cached_results.write(conv3d_result_cache_name);
|
| 439 |
+
}
|
| 440 |
+
} // if (!cached_result_loaded)
|
| 441 |
+
|
| 442 |
+
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
|
| 443 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 444 |
+
passed = (tensor_D_hash == cached_test_result.D);
|
| 445 |
+
|
| 446 |
+
EXPECT_EQ(tensor_D_hash, cached_test_result.D)
|
| 447 |
+
<< "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
|
| 448 |
+
}
|
| 449 |
+
else {
|
| 450 |
+
|
| 451 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 452 |
+
tensor_D_computed.host_view(),
|
| 453 |
+
tensor_D_reference.host_view());
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
EXPECT_TRUE(passed);
|
| 457 |
+
|
| 458 |
+
if (!passed) {
|
| 459 |
+
std::stringstream fname;
|
| 460 |
+
|
| 461 |
+
fname << "error_Conv3d_ImplicitGemm_device_"
|
| 462 |
+
<< (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
|
| 463 |
+
<< (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
|
| 464 |
+
(Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
|
| 465 |
+
(Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
|
| 466 |
+
<< "ndhwc_"
|
| 467 |
+
<< problem_size.N << "x"
|
| 468 |
+
<< problem_size.D << "x"
|
| 469 |
+
<< problem_size.H << "x"
|
| 470 |
+
<< problem_size.W << "x"
|
| 471 |
+
<< problem_size.C
|
| 472 |
+
<< "_ktrsc_"
|
| 473 |
+
<< problem_size.K << "x"
|
| 474 |
+
<< problem_size.T << "x"
|
| 475 |
+
<< problem_size.R << "x"
|
| 476 |
+
<< problem_size.S << "x"
|
| 477 |
+
<< problem_size.C
|
| 478 |
+
<< "_padding_"
|
| 479 |
+
<< problem_size.pad_d << "x"
|
| 480 |
+
<< problem_size.pad_h << "x"
|
| 481 |
+
<< problem_size.pad_w
|
| 482 |
+
<< "_stride_"
|
| 483 |
+
<< problem_size.stride_d << "x"
|
| 484 |
+
<< problem_size.stride_h << "x"
|
| 485 |
+
<< problem_size.stride_w
|
| 486 |
+
<< "_dilation_"
|
| 487 |
+
<< problem_size.dilation_d << "x"
|
| 488 |
+
<< problem_size.dilation_h << "x"
|
| 489 |
+
<< problem_size.dilation_w << "_"
|
| 490 |
+
<< (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
|
| 491 |
+
<< Conv3d::ThreadblockShape::kM << "x"
|
| 492 |
+
<< Conv3d::ThreadblockShape::kN << "x"
|
| 493 |
+
<< Conv3d::ThreadblockShape::kK << "_"
|
| 494 |
+
<< Conv3d::WarpShape::kM << "x"
|
| 495 |
+
<< Conv3d::WarpShape::kN << "x"
|
| 496 |
+
<< Conv3d::WarpShape::kK << ".txt";
|
| 497 |
+
|
| 498 |
+
std::cout << fname.str() << std::endl;
|
| 499 |
+
|
| 500 |
+
std::ofstream results(fname.str());
|
| 501 |
+
|
| 502 |
+
results << problem_size << std::endl;
|
| 503 |
+
|
| 504 |
+
results
|
| 505 |
+
<< "\nA:\n" << tensor_A.host_view() << "\n"
|
| 506 |
+
<< "\nB:\n" << tensor_B.host_view() << "\n"
|
| 507 |
+
<< "\nC:\n" << tensor_C.host_view() << "\n";
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
results << "\nD reference (hash: " << cached_test_result.D << ")\n";
|
| 511 |
+
|
| 512 |
+
if (!cached_result_loaded) {
|
| 513 |
+
results
|
| 514 |
+
<< tensor_D_reference.host_view() << "\n";
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
results
|
| 518 |
+
<< "\nD computed (hash: " << tensor_D_hash << ")\n"
|
| 519 |
+
<< tensor_D_computed.host_view() << "\n";
|
| 520 |
+
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
return passed;
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
};
|
| 527 |
+
|
| 528 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 529 |
+
// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
|
| 530 |
+
// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
|
| 531 |
+
// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
|
| 532 |
+
// (conv_blacklist_sizes)
|
| 533 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 534 |
+
|
| 535 |
+
template <typename ImplicitGemm>
|
| 536 |
+
bool TestAllConv3d(
|
| 537 |
+
const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(),
|
| 538 |
+
const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) {
|
| 539 |
+
|
| 540 |
+
bool passed = true;
|
| 541 |
+
|
| 542 |
+
//
|
| 543 |
+
// Testbed object
|
| 544 |
+
//
|
| 545 |
+
|
| 546 |
+
//TestbedConv3d<ImplicitGemm> testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential);
|
| 547 |
+
TestbedConv3d<ImplicitGemm> testbed;
|
| 548 |
+
|
| 549 |
+
//
|
| 550 |
+
// Get conv problem sizes to run conv operator
|
| 551 |
+
//
|
| 552 |
+
TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
|
| 553 |
+
|
| 554 |
+
// Vector of conv3d problem sizes to avoid duplicate runs
|
| 555 |
+
Conv3dProblemVector conv_tested_sizes;
|
| 556 |
+
|
| 557 |
+
Conv3dProblemVector const *problem_vectors[] = {
|
| 558 |
+
&conv3d_problems.conv3d_default_sizes,
|
| 559 |
+
&conv3d_problems.conv3d_vnet_medical_sizes,
|
| 560 |
+
&conv_test_sizes
|
| 561 |
+
};
|
| 562 |
+
|
| 563 |
+
// Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 564 |
+
for (Conv3dProblemVector const * problem_vector : problem_vectors) {
|
| 565 |
+
|
| 566 |
+
// Run conv testbed on default convolution sizes
|
| 567 |
+
for(auto conv_problem : *problem_vector) {
|
| 568 |
+
|
| 569 |
+
// Skip blacklist and avoid duplicate problem sizes
|
| 570 |
+
if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
|
| 571 |
+
std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
|
| 572 |
+
continue;
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
//
|
| 576 |
+
// Procedurally disable certain cases
|
| 577 |
+
//
|
| 578 |
+
|
| 579 |
+
// CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1}
|
| 580 |
+
if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
|
| 581 |
+
ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
|
| 582 |
+
((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
|
| 583 |
+
cutlass::conv::StrideSupport::kUnity) ||
|
| 584 |
+
(ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport ==
|
| 585 |
+
cutlass::conv::StrideSupport::kUnity))) {
|
| 586 |
+
if (!((conv_problem.stride_d == 1) &&
|
| 587 |
+
(conv_problem.stride_h == 1) &&
|
| 588 |
+
(conv_problem.stride_w == 1))
|
| 589 |
+
) {
|
| 590 |
+
continue;
|
| 591 |
+
}
|
| 592 |
+
}
|
| 593 |
+
|
| 594 |
+
//
|
| 595 |
+
// Test
|
| 596 |
+
//
|
| 597 |
+
// push back tested problem size to avoid re-running duplicates
|
| 598 |
+
conv_tested_sizes.push_back(conv_problem);
|
| 599 |
+
|
| 600 |
+
// test mode = xcross
|
| 601 |
+
passed = testbed.run(
|
| 602 |
+
conv_problem,
|
| 603 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 604 |
+
|
| 605 |
+
if (!passed) {
|
| 606 |
+
return false;
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
// test mode = convolution
|
| 610 |
+
passed = testbed.run(
|
| 611 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 612 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 613 |
+
|
| 614 |
+
if (!passed) {
|
| 615 |
+
return false;
|
| 616 |
+
}
|
| 617 |
+
}
|
| 618 |
+
}
|
| 619 |
+
|
| 620 |
+
// Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for
|
| 621 |
+
// a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
|
| 622 |
+
// which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
|
| 623 |
+
// alpha and beta for local testing, but only runs one value for alpha and beta.
|
| 624 |
+
cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size (
|
| 625 |
+
{1, 8, 8, 8, 32}, // input size (NDHWC)
|
| 626 |
+
{32, 3, 3, 3, 32}, // filter size (KTRSC)
|
| 627 |
+
cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
|
| 628 |
+
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
|
| 629 |
+
cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
|
| 630 |
+
);
|
| 631 |
+
|
| 632 |
+
cutlass::conv::SplitKMode split_k_modes [] = {
|
| 633 |
+
cutlass::conv::SplitKMode::kSerial,
|
| 634 |
+
cutlass::conv::SplitKMode::kParallel
|
| 635 |
+
};
|
| 636 |
+
|
| 637 |
+
int split_k_slices[] = {
|
| 638 |
+
1, 2, 3, 4, 201
|
| 639 |
+
};
|
| 640 |
+
|
| 641 |
+
double problem_alpha[] = {
|
| 642 |
+
2.0
|
| 643 |
+
};
|
| 644 |
+
|
| 645 |
+
double problem_beta[] = {
|
| 646 |
+
2.0
|
| 647 |
+
};
|
| 648 |
+
|
| 649 |
+
for (auto split_k_mode : split_k_modes) {
|
| 650 |
+
for (auto split_k_slice : split_k_slices) {
|
| 651 |
+
for (auto alpha : problem_alpha) {
|
| 652 |
+
for (auto beta : problem_beta) {
|
| 653 |
+
|
| 654 |
+
passed = testbed.run(
|
| 655 |
+
conv3d_split_k_test_size.reset_split_k_slices(split_k_slice),
|
| 656 |
+
split_k_mode,
|
| 657 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
|
| 658 |
+
cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
|
| 659 |
+
|
| 660 |
+
if (!passed) {
|
| 661 |
+
return false;
|
| 662 |
+
}
|
| 663 |
+
}
|
| 664 |
+
}
|
| 665 |
+
}
|
| 666 |
+
}
|
| 667 |
+
|
| 668 |
+
return passed;
|
| 669 |
+
}
|
| 670 |
+
|
| 671 |
+
template <typename ImplicitGemm>
|
| 672 |
+
bool TestSpecificConv3d(
|
| 673 |
+
const Conv3dProblemVector & problem_sizes) {
|
| 674 |
+
|
| 675 |
+
bool passed = true;
|
| 676 |
+
|
| 677 |
+
//
|
| 678 |
+
// Testbed object
|
| 679 |
+
//
|
| 680 |
+
|
| 681 |
+
TestbedConv3d<ImplicitGemm> testbed;
|
| 682 |
+
|
| 683 |
+
// Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 684 |
+
for(auto conv_problem : problem_sizes) {
|
| 685 |
+
|
| 686 |
+
//
|
| 687 |
+
// Test
|
| 688 |
+
//
|
| 689 |
+
|
| 690 |
+
// test mode = xcross
|
| 691 |
+
passed = testbed.run(
|
| 692 |
+
conv_problem,
|
| 693 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 694 |
+
|
| 695 |
+
if (!passed) {
|
| 696 |
+
return false;
|
| 697 |
+
}
|
| 698 |
+
|
| 699 |
+
// test mode = convolution
|
| 700 |
+
passed = testbed.run(
|
| 701 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 702 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 703 |
+
|
| 704 |
+
if (!passed) {
|
| 705 |
+
return false;
|
| 706 |
+
}
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
return true;
|
| 710 |
+
}
|
| 711 |
+
|
| 712 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 713 |
+
|
| 714 |
+
} // namespace device
|
| 715 |
+
} // namespace conv
|
| 716 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h
ADDED
|
@@ -0,0 +1,732 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM for fused epilogue broadcast testbed
|
| 33 |
+
|
| 34 |
+
Parallel split-k is not tested because we can just use regular conv kernel
|
| 35 |
+
when we need to use parallel-splitk. Broadcast can happen in the reduction
|
| 36 |
+
kernel.
|
| 37 |
+
*/
|
| 38 |
+
#pragma once
|
| 39 |
+
|
| 40 |
+
#include <fstream>
|
| 41 |
+
|
| 42 |
+
#include "../../common/cutlass_unit_test.h"
|
| 43 |
+
#include "cutlass/cutlass.h"
|
| 44 |
+
|
| 45 |
+
#include "cutlass/conv/device/implicit_gemm_convolution.h"
|
| 46 |
+
#include "cutlass/reduction/device/reduce_split_k.h"
|
| 47 |
+
#include "cutlass/reduction/thread/reduction_operators.h"
|
| 48 |
+
|
| 49 |
+
#include "conv3d_problems.h"
|
| 50 |
+
|
| 51 |
+
#include "cutlass/util/host_tensor.h"
|
| 52 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 53 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 54 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 55 |
+
|
| 56 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 57 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 58 |
+
|
| 59 |
+
#include "cutlass/core_io.h"
|
| 60 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 61 |
+
|
| 62 |
+
#include "../cache_testbed_output.h"
|
| 63 |
+
|
| 64 |
+
namespace test {
|
| 65 |
+
namespace conv {
|
| 66 |
+
namespace device {
|
| 67 |
+
|
| 68 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
|
| 70 |
+
template <typename Conv3d>
|
| 71 |
+
struct Conv3dWithBroadcastReferenceOp {
|
| 72 |
+
|
| 73 |
+
using OutputOp = typename Conv3d::EpilogueOutputOp;
|
| 74 |
+
|
| 75 |
+
using ElementCompute = typename OutputOp::ElementCompute;
|
| 76 |
+
using ElementZ = typename OutputOp::ElementZ;
|
| 77 |
+
using ElementT = typename OutputOp::ElementT;
|
| 78 |
+
|
| 79 |
+
typename OutputOp::BinaryOp binary_op;
|
| 80 |
+
typename OutputOp::ElementwiseOp elementwise_op;
|
| 81 |
+
|
| 82 |
+
Conv3dWithBroadcastReferenceOp() { }
|
| 83 |
+
|
| 84 |
+
void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) {
|
| 85 |
+
ElementCompute t_full = binary_op(conv3d, bias);
|
| 86 |
+
T = ElementT(t_full);
|
| 87 |
+
|
| 88 |
+
ElementCompute z_full = elementwise_op(t_full);
|
| 89 |
+
Z = ElementZ(z_full);
|
| 90 |
+
}
|
| 91 |
+
};
|
| 92 |
+
|
| 93 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 94 |
+
|
| 95 |
+
// Fused testbed
|
| 96 |
+
//
|
| 97 |
+
// Y = CONV(AB, C)
|
| 98 |
+
//
|
| 99 |
+
// T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k])
|
| 100 |
+
//
|
| 101 |
+
// Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k])
|
| 102 |
+
//
|
| 103 |
+
|
| 104 |
+
template <
|
| 105 |
+
typename Conv3d,
|
| 106 |
+
typename ReferenceOp,
|
| 107 |
+
bool AddBroadcastFirst = false
|
| 108 |
+
>
|
| 109 |
+
class TestbedConv3dWithBroadcast {
|
| 110 |
+
public:
|
| 111 |
+
|
| 112 |
+
using ElementA = typename Conv3d::ElementA;
|
| 113 |
+
using LayoutA = typename Conv3d::LayoutA;
|
| 114 |
+
using ElementB = typename Conv3d::ElementB;
|
| 115 |
+
using LayoutB = typename Conv3d::LayoutB;
|
| 116 |
+
using ElementC = typename Conv3d::ElementC;
|
| 117 |
+
using LayoutC = typename Conv3d::LayoutC;
|
| 118 |
+
using ElementAccumulator = typename Conv3d::ElementAccumulator;
|
| 119 |
+
using ElementCompute = typename Conv3d::ElementCompute;
|
| 120 |
+
using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp;
|
| 121 |
+
using ElementZ = typename EpilogueOutputOp::ElementZ;
|
| 122 |
+
using ElementT = typename EpilogueOutputOp::ElementT;
|
| 123 |
+
using ElementVector = typename EpilogueOutputOp::ElementVector;
|
| 124 |
+
|
| 125 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator;
|
| 126 |
+
static const bool kAddBroadcastFirst = AddBroadcastFirst;
|
| 127 |
+
static const bool kStoreT = EpilogueOutputOp::kStoreT;
|
| 128 |
+
|
| 129 |
+
public:
|
| 130 |
+
|
| 131 |
+
/// Initialization
|
| 132 |
+
cutlass::Distribution::Kind init_A;
|
| 133 |
+
cutlass::Distribution::Kind init_B;
|
| 134 |
+
cutlass::Distribution::Kind init_C;
|
| 135 |
+
uint64_t seed;
|
| 136 |
+
|
| 137 |
+
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
| 138 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
|
| 139 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
| 140 |
+
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_C_reference;
|
| 141 |
+
cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_computed;
|
| 142 |
+
cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_reference;
|
| 143 |
+
cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
|
| 144 |
+
cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
|
| 145 |
+
cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
|
| 146 |
+
cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
|
| 147 |
+
|
| 148 |
+
public:
|
| 149 |
+
|
| 150 |
+
TestbedConv3dWithBroadcast(
|
| 151 |
+
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
|
| 152 |
+
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
|
| 153 |
+
cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
|
| 154 |
+
uint64_t seed_ = 2080
|
| 155 |
+
):
|
| 156 |
+
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
|
| 157 |
+
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
/// Helper to initialize a tensor view
|
| 161 |
+
template <typename Element, typename Layout>
|
| 162 |
+
void initialize_tensor(
|
| 163 |
+
cutlass::TensorView<Element, Layout> view,
|
| 164 |
+
cutlass::Distribution::Kind dist_kind,
|
| 165 |
+
uint64_t seed) {
|
| 166 |
+
|
| 167 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 168 |
+
|
| 169 |
+
int scope;
|
| 170 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 171 |
+
|
| 172 |
+
if (bits <= 8) {
|
| 173 |
+
scope = 2;
|
| 174 |
+
}
|
| 175 |
+
else if (bits == 16) {
|
| 176 |
+
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
|
| 177 |
+
scope = 3;
|
| 178 |
+
}
|
| 179 |
+
else {
|
| 180 |
+
scope = 5;
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
else {
|
| 184 |
+
scope = 8;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
cutlass::reference::host::TensorFillRandomUniform(
|
| 188 |
+
view, seed, scope, -scope, 0);
|
| 189 |
+
}
|
| 190 |
+
else if (dist_kind == cutlass::Distribution::Identity) {
|
| 191 |
+
|
| 192 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 193 |
+
}
|
| 194 |
+
else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 195 |
+
|
| 196 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 197 |
+
}
|
| 198 |
+
else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 199 |
+
|
| 200 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 201 |
+
}
|
| 202 |
+
else {
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
void initialize(
|
| 207 |
+
cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) {
|
| 208 |
+
|
| 209 |
+
// to make the layout of tensors a little bit bigger than the problem size
|
| 210 |
+
cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64);
|
| 211 |
+
|
| 212 |
+
cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size);
|
| 213 |
+
cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size);
|
| 214 |
+
cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size);
|
| 215 |
+
|
| 216 |
+
if (non_packed_test) {
|
| 217 |
+
tensor_A_extent += stride_increment;
|
| 218 |
+
tensor_C_extent += stride_increment;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
tensor_A.resize(tensor_A_extent);
|
| 222 |
+
tensor_B.resize(tensor_B_extent);
|
| 223 |
+
tensor_C.resize(tensor_C_extent);
|
| 224 |
+
tensor_C_reference.resize(tensor_C_extent);
|
| 225 |
+
tensor_Z_computed.resize(tensor_C_extent);
|
| 226 |
+
tensor_Z_reference.resize(tensor_C_extent);
|
| 227 |
+
tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 228 |
+
tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
|
| 229 |
+
tensor_Y_reference.resize(tensor_C_extent);
|
| 230 |
+
tensor_Broadcast.resize({
|
| 231 |
+
1,
|
| 232 |
+
1,
|
| 233 |
+
1,
|
| 234 |
+
1,
|
| 235 |
+
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(),
|
| 236 |
+
});
|
| 237 |
+
|
| 238 |
+
initialize_tensor(tensor_A.host_view(), init_A, seed);
|
| 239 |
+
initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
|
| 240 |
+
initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
|
| 241 |
+
initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39);
|
| 242 |
+
for (int n = 0; n < tensor_C_reference.extent().n(); ++n) {
|
| 243 |
+
for (int o = 0; o < tensor_C_reference.extent().d(); ++o) {
|
| 244 |
+
for (int p = 0; p < tensor_C_reference.extent().h(); ++p) {
|
| 245 |
+
for (int q = 0; q < tensor_C_reference.extent().w(); ++q) {
|
| 246 |
+
for (int k = 0; k < tensor_C_reference.extent().c(); ++k) {
|
| 247 |
+
tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k}));
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
}
|
| 253 |
+
tensor_A.sync_device();
|
| 254 |
+
tensor_B.sync_device();
|
| 255 |
+
tensor_C.sync_device();
|
| 256 |
+
tensor_Broadcast.sync_device();
|
| 257 |
+
tensor_C_reference.sync_device();
|
| 258 |
+
tensor_Z_computed.sync_device();
|
| 259 |
+
tensor_Z_reference.sync_device();
|
| 260 |
+
tensor_T_computed.sync_device();
|
| 261 |
+
tensor_T_reference.sync_device();
|
| 262 |
+
tensor_Y_reference.sync_device();
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
bool sufficient() const {
|
| 266 |
+
//
|
| 267 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 268 |
+
//
|
| 269 |
+
|
| 270 |
+
size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage);
|
| 271 |
+
|
| 272 |
+
cudaDeviceProp properties;
|
| 273 |
+
int device_idx;
|
| 274 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 275 |
+
|
| 276 |
+
if (result != cudaSuccess) {
|
| 277 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
result = cudaGetDeviceProperties(&properties, device_idx);
|
| 281 |
+
|
| 282 |
+
if (result != cudaSuccess) {
|
| 283 |
+
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
if (properties.sharedMemPerBlockOptin < smem_size) {
|
| 287 |
+
return false;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
return true;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
/// Executes one test
|
| 294 |
+
bool run(
|
| 295 |
+
cutlass::conv::Conv3dProblemSize const &problem_size,
|
| 296 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 297 |
+
bool non_packed_test = false,
|
| 298 |
+
ElementCompute alpha = ElementCompute(1),
|
| 299 |
+
ElementCompute beta = ElementCompute(1)) {
|
| 300 |
+
|
| 301 |
+
// Waive test if insufficient CUDA device
|
| 302 |
+
if (!sufficient()) {
|
| 303 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 304 |
+
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
|
| 305 |
+
}
|
| 306 |
+
return true;
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
#if 0 //display conv3d problem size for debugging
|
| 310 |
+
std::cout << problem_size << std::endl
|
| 311 |
+
<< "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
|
| 312 |
+
<< "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
|
| 313 |
+
<< std::endl;
|
| 314 |
+
#endif
|
| 315 |
+
|
| 316 |
+
initialize(problem_size, non_packed_test);
|
| 317 |
+
|
| 318 |
+
// configure the operator
|
| 319 |
+
Conv3d conv3d_op;
|
| 320 |
+
typename Conv3d::Arguments conv3d_args(
|
| 321 |
+
problem_size,
|
| 322 |
+
tensor_A.device_ref(),
|
| 323 |
+
tensor_B.device_ref(),
|
| 324 |
+
tensor_C.device_ref(),
|
| 325 |
+
tensor_Z_computed.device_ref(),
|
| 326 |
+
{alpha, beta},
|
| 327 |
+
split_k_mode,
|
| 328 |
+
tensor_Broadcast.device_data(),
|
| 329 |
+
kStoreT ? tensor_T_computed.device_data() : nullptr,
|
| 330 |
+
0, // This must be zero
|
| 331 |
+
implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
|
| 332 |
+
);
|
| 333 |
+
|
| 334 |
+
// initialize the kernel
|
| 335 |
+
size_t workspace_size = Conv3d::get_workspace_size(conv3d_args);
|
| 336 |
+
|
| 337 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 338 |
+
|
| 339 |
+
cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get());
|
| 340 |
+
|
| 341 |
+
if (status != cutlass::Status::kSuccess) {
|
| 342 |
+
cudaError_t error = cudaGetLastError();
|
| 343 |
+
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
|
| 344 |
+
return true;
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// run conv3d operator
|
| 348 |
+
status = conv3d_op();
|
| 349 |
+
|
| 350 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 351 |
+
if (status != cutlass::Status::kSuccess) {
|
| 352 |
+
return false;
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
bool passed = false;
|
| 356 |
+
|
| 357 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 358 |
+
EXPECT_EQ(result, cudaSuccess) << " device reference error: "
|
| 359 |
+
<< cudaGetErrorString(result);
|
| 360 |
+
|
| 361 |
+
tensor_T_computed.sync_host();
|
| 362 |
+
tensor_Z_computed.sync_host();
|
| 363 |
+
|
| 364 |
+
//
|
| 365 |
+
// Reference check
|
| 366 |
+
//
|
| 367 |
+
|
| 368 |
+
// When kAddBroadcastFirst is true, add bias on the host
|
| 369 |
+
ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta;
|
| 370 |
+
|
| 371 |
+
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
| 372 |
+
|
| 373 |
+
cutlass::reference::device::Conv3d<
|
| 374 |
+
ElementA,
|
| 375 |
+
LayoutA,
|
| 376 |
+
ElementB,
|
| 377 |
+
LayoutB,
|
| 378 |
+
ElementAccumulator,
|
| 379 |
+
LayoutC,
|
| 380 |
+
ElementAccumulator,
|
| 381 |
+
ElementAccumulator
|
| 382 |
+
>(
|
| 383 |
+
kConvolutionalOperator,
|
| 384 |
+
problem_size,
|
| 385 |
+
tensor_A.device_ref(),
|
| 386 |
+
tensor_B.device_ref(),
|
| 387 |
+
tensor_C_reference.device_ref(),
|
| 388 |
+
tensor_Y_reference.device_ref(),
|
| 389 |
+
alpha,
|
| 390 |
+
beta_ref);
|
| 391 |
+
|
| 392 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 393 |
+
tensor_Y_reference.sync_host();
|
| 394 |
+
|
| 395 |
+
#else
|
| 396 |
+
|
| 397 |
+
cutlass::reference::host::Conv3d<
|
| 398 |
+
ElementA,
|
| 399 |
+
LayoutA,
|
| 400 |
+
ElementB,
|
| 401 |
+
LayoutB,
|
| 402 |
+
ElementAccumulator,
|
| 403 |
+
LayoutC,
|
| 404 |
+
ElementAccumulator,
|
| 405 |
+
ElementAccumulator
|
| 406 |
+
>(
|
| 407 |
+
kConvolutionalOperator,
|
| 408 |
+
problem_size,
|
| 409 |
+
tensor_A.host_ref(),
|
| 410 |
+
tensor_B.host_ref(),
|
| 411 |
+
tensor_C_reference.host_ref(),
|
| 412 |
+
tensor_Y_reference.host_ref(),
|
| 413 |
+
alpha,
|
| 414 |
+
beta_ref);
|
| 415 |
+
|
| 416 |
+
#endif
|
| 417 |
+
ReferenceOp reference_op;
|
| 418 |
+
|
| 419 |
+
// compute tensor Z and tensor T
|
| 420 |
+
for (int n = 0; n < problem_size.N; ++n) {
|
| 421 |
+
for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) {
|
| 422 |
+
for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) {
|
| 423 |
+
for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) {
|
| 424 |
+
for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) {
|
| 425 |
+
|
| 426 |
+
ElementZ z{};
|
| 427 |
+
ElementT t{};
|
| 428 |
+
|
| 429 |
+
ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k});
|
| 430 |
+
ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k}));
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
if (kAddBroadcastFirst) {
|
| 434 |
+
reference_op(z, t, accum + bias,
|
| 435 |
+
beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k})));
|
| 436 |
+
} else {
|
| 437 |
+
reference_op(z, t, accum, bias);
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
tensor_Z_reference.at({n, o, p, q, k}) = z;
|
| 441 |
+
tensor_T_reference.at({n, o, p, q, k}) = t;
|
| 442 |
+
}
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
}
|
| 447 |
+
|
| 448 |
+
if (kStoreT) {
|
| 449 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 450 |
+
tensor_T_computed.host_view(),
|
| 451 |
+
tensor_T_reference.host_view());
|
| 452 |
+
|
| 453 |
+
EXPECT_TRUE(passed);
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 457 |
+
tensor_Z_computed.host_view(),
|
| 458 |
+
tensor_Z_reference.host_view());
|
| 459 |
+
|
| 460 |
+
EXPECT_TRUE(passed);
|
| 461 |
+
|
| 462 |
+
if (!passed) {
|
| 463 |
+
std::stringstream fname;
|
| 464 |
+
|
| 465 |
+
fname << "error_Conv3d_ImplicitGemm_device_"
|
| 466 |
+
<< (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
|
| 467 |
+
<< (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
|
| 468 |
+
(Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
|
| 469 |
+
(Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
|
| 470 |
+
<< "nnhwc_"
|
| 471 |
+
<< problem_size.N << "x"
|
| 472 |
+
<< problem_size.D << "x"
|
| 473 |
+
<< problem_size.H << "x"
|
| 474 |
+
<< problem_size.W << "x"
|
| 475 |
+
<< problem_size.C
|
| 476 |
+
<< "_krsc_"
|
| 477 |
+
<< problem_size.K << "x"
|
| 478 |
+
<< problem_size.T << "x"
|
| 479 |
+
<< problem_size.R << "x"
|
| 480 |
+
<< problem_size.S << "x"
|
| 481 |
+
<< problem_size.C
|
| 482 |
+
<< "_padding_"
|
| 483 |
+
<< problem_size.pad_d << "x"
|
| 484 |
+
<< problem_size.pad_h << "x"
|
| 485 |
+
<< problem_size.pad_w
|
| 486 |
+
<< "_stride_"
|
| 487 |
+
<< problem_size.stride_d << "x"
|
| 488 |
+
<< problem_size.stride_h << "x"
|
| 489 |
+
<< problem_size.stride_w
|
| 490 |
+
<< "_dilation_"
|
| 491 |
+
<< problem_size.dilation_d << "x"
|
| 492 |
+
<< problem_size.dilation_h << "x"
|
| 493 |
+
<< problem_size.dilation_w << "_"
|
| 494 |
+
<< (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
|
| 495 |
+
<< (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_")
|
| 496 |
+
<< Conv3d::ThreadblockShape::kM << "x"
|
| 497 |
+
<< Conv3d::ThreadblockShape::kN << "x"
|
| 498 |
+
<< Conv3d::ThreadblockShape::kK << "_"
|
| 499 |
+
<< Conv3d::WarpShape::kM << "x"
|
| 500 |
+
<< Conv3d::WarpShape::kN << "x"
|
| 501 |
+
<< Conv3d::WarpShape::kK << ".txt";
|
| 502 |
+
|
| 503 |
+
std::cout << fname.str() << std::endl;
|
| 504 |
+
|
| 505 |
+
std::ofstream results(fname.str());
|
| 506 |
+
|
| 507 |
+
results << problem_size << std::endl;
|
| 508 |
+
|
| 509 |
+
results
|
| 510 |
+
<< "\nA:\n" << tensor_A.host_view() << "\n"
|
| 511 |
+
<< "\nB:\n" << tensor_B.host_view() << "\n"
|
| 512 |
+
<< "\nC:\n" << tensor_C.host_view() << "\n"
|
| 513 |
+
<< "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n"
|
| 514 |
+
<< "\nY reference:\n" << tensor_Y_reference.host_view() << "\n"
|
| 515 |
+
<< "\nT reference:\n" << tensor_T_reference.host_view() << "\n"
|
| 516 |
+
<< "\nT computed:\n" << tensor_T_computed.host_view() << "\n"
|
| 517 |
+
<< "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n"
|
| 518 |
+
<< "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n";
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
return passed;
|
| 522 |
+
}
|
| 523 |
+
};
|
| 524 |
+
|
| 525 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////
// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes
// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
// (conv_blacklist_sizes)
//
// Returns true if every tested problem passes (or is procedurally skipped); returns false on the
// first failing problem so the offending size is easy to identify from the testbed's error dump.
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ImplicitGemm,
          typename ReferenceOp = Conv3dWithBroadcastReferenceOp<ImplicitGemm>,
          bool AddBroadcastFirst = false,
          bool TestSplitK = true
         >
bool TestAllConv3dWithBroadcast(
  const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(),
  const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(),
  bool non_packed_test = false) {

  bool passed = true;

  //
  // Testbed object
  //

  TestbedConv3dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;

  //
  // Get conv problem sizes to run conv operator
  //
  // The constructor argument is the minimum channel count (elements per 128 bits of ElementA),
  // so generated problem sizes remain vector-aligned for the kernel's memory accesses.
  TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);

  // Vector of conv3d problem sizes to avoid duplicate runs
  Conv3dProblemVector conv_tested_sizes;

  Conv3dProblemVector const *problem_vectors[] = {
    &conv3d_problems.conv3d_default_sizes,
    &conv3d_problems.conv3d_vnet_medical_sizes,
    &conv_test_sizes
  };

  // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (Conv3dProblemVector const * problem_vector : problem_vectors) {

    // Run conv testbed on default convolution sizes
    for(auto conv_problem : *problem_vector) {

      // Skip blacklist and avoid duplicate problem sizes
      if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
          std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
        continue;
      }

      //
      // Procedurally disable certain cases
      //

      // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
      if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
           ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
           cutlass::conv::StrideSupport::kUnity)) {
        if (!((conv_problem.stride_d == 1) &&
              (conv_problem.stride_h == 1) &&
              (conv_problem.stride_w == 1))
          ) {
          continue;
        }
      }

#if 0 // relax restrictions on analytic strided dgrad
      // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
      if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
           ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
           cutlass::conv::StrideSupport::kStrided)) {
        if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }
#endif

      //
      // Test
      //
      // push back tested problem size to avoid re-running duplicates
      conv_tested_sizes.push_back(conv_problem);

      // test mode = xcross
      passed = testbed.run(
        conv_problem,
        cutlass::conv::SplitKMode::kSerial, non_packed_test);

      if (!passed) {
        return false;
      }

      // test mode = convolution
      passed = testbed.run(
        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
        cutlass::conv::SplitKMode::kSerial, non_packed_test);

      if (!passed) {
        return false;
      }
    }
  }

  if (!TestSplitK)
    return passed;

  // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for
  // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters
  // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep
  // alpha and beta for local testing, but only runs one value for alpha and beta.
  cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size (
    {1, 8, 8, 8, 32},             // input size  (NDHWC)
    {32, 3, 3, 3, 32},            // filter size (KTRSC)
    cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
    cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
    cutlass::Coord<3>({1, 1, 1})  // dilation (dilation_d, dilation_h, dilation_w)
  );

  cutlass::conv::SplitKMode split_k_modes [] = {
    cutlass::conv::SplitKMode::kSerial
  };

  // 201 deliberately exceeds the K extent of the problem so degenerate (empty) slices are exercised.
  int split_k_slices[] = {
    1, 2, 3, 4, 201
  };

  double problem_alpha[] = {
    2.0
  };

  double problem_beta[] = {
    2.0
  };

  for (auto split_k_mode : split_k_modes) {
    for (auto split_k_slice : split_k_slices) {
      for (auto alpha : problem_alpha) {
        for (auto beta : problem_beta) {

          passed = testbed.run(
            conv3d_split_k_test_size.reset_split_k_slices(split_k_slice),
            split_k_mode,
            false,/*non_packed_test*/
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));

          if (!passed) {
            return false;
          }
        }
      }
    }
  }

  return passed;
}
|
| 683 |
+
|
| 684 |
+
template <typename ImplicitGemm,
|
| 685 |
+
typename ReferenceOp = Conv3dWithBroadcastReferenceOp<ImplicitGemm>,
|
| 686 |
+
bool AddBroadcastFirst = false>
|
| 687 |
+
bool TestSpecificConv3dWithBroadcast(
|
| 688 |
+
const Conv3dProblemVector & problem_sizes,
|
| 689 |
+
bool non_packed_test = false) {
|
| 690 |
+
|
| 691 |
+
bool passed = true;
|
| 692 |
+
|
| 693 |
+
//
|
| 694 |
+
// Testbed object
|
| 695 |
+
//
|
| 696 |
+
|
| 697 |
+
TestbedConv3dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
|
| 698 |
+
|
| 699 |
+
// Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 700 |
+
for(auto conv_problem : problem_sizes) {
|
| 701 |
+
|
| 702 |
+
//
|
| 703 |
+
// Test
|
| 704 |
+
//
|
| 705 |
+
|
| 706 |
+
// test mode = xcross, non_packed_test = false
|
| 707 |
+
passed = testbed.run(
|
| 708 |
+
conv_problem,
|
| 709 |
+
cutlass::conv::SplitKMode::kSerial, non_packed_test);
|
| 710 |
+
|
| 711 |
+
if (!passed) {
|
| 712 |
+
return false;
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
// test mode = convolution, non_packed_test = false
|
| 716 |
+
passed = testbed.run(
|
| 717 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 718 |
+
cutlass::conv::SplitKMode::kSerial, non_packed_test);
|
| 719 |
+
|
| 720 |
+
if (!passed) {
|
| 721 |
+
return false;
|
| 722 |
+
}
|
| 723 |
+
}
|
| 724 |
+
|
| 725 |
+
return true;
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 729 |
+
|
| 730 |
+
} // namespace device
|
| 731 |
+
} // namespace conv
|
| 732 |
+
} // namespace test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h
ADDED
|
@@ -0,0 +1,473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Depthwise Direct Conv testbed
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include <fstream>
|
| 37 |
+
|
| 38 |
+
#include "../../common/cutlass_unit_test.h"
|
| 39 |
+
#include "../cache_testbed_output.h"
|
| 40 |
+
#include "conv2d_problems.h"
|
| 41 |
+
#include "cutlass/conv/device/direct_convolution.h"
|
| 42 |
+
|
| 43 |
+
#include "cutlass/core_io.h"
|
| 44 |
+
#include "cutlass/cutlass.h"
|
| 45 |
+
#include "cutlass/util/host_tensor.h"
|
| 46 |
+
#include "cutlass/util/reference/device/convolution.h"
|
| 47 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 48 |
+
#include "cutlass/util/reference/host/convolution.h"
|
| 49 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 50 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 51 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 52 |
+
|
| 53 |
+
namespace test {
|
| 54 |
+
namespace conv {
|
| 55 |
+
namespace device {
|
| 56 |
+
|
| 57 |
+
template <typename Conv2d>
|
| 58 |
+
class TestbedDepthwiseDirectConv2d {
|
| 59 |
+
public:
|
| 60 |
+
|
| 61 |
+
using ElementA = typename Conv2d::ElementA;
|
| 62 |
+
using LayoutA = typename Conv2d::LayoutA;
|
| 63 |
+
using ElementB = typename Conv2d::ElementB;
|
| 64 |
+
using LayoutB = typename Conv2d::LayoutB;
|
| 65 |
+
using ElementC = typename Conv2d::ElementC;
|
| 66 |
+
using LayoutC = typename Conv2d::LayoutC;
|
| 67 |
+
using ElementAccumulator = typename Conv2d::ElementAccumulator;
|
| 68 |
+
using ElementCompute = typename Conv2d::ElementCompute;
|
| 69 |
+
using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
|
| 70 |
+
|
| 71 |
+
static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
|
| 72 |
+
|
| 73 |
+
public:
|
| 74 |
+
/// Initialization
|
| 75 |
+
cutlass::Distribution::Kind init_A;
|
| 76 |
+
cutlass::Distribution::Kind init_B;
|
| 77 |
+
cutlass::Distribution::Kind init_C;
|
| 78 |
+
uint64_t seed;
|
| 79 |
+
|
| 80 |
+
cutlass::HostTensor<ElementA, LayoutA> tensor_A;
|
| 81 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_B;
|
| 82 |
+
cutlass::HostTensor<ElementB, LayoutB> tensor_reordered_B;
|
| 83 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_C;
|
| 84 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
|
| 85 |
+
cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
|
| 86 |
+
|
| 87 |
+
int tested_problem_count;
|
| 88 |
+
|
| 89 |
+
public:
|
| 90 |
+
  /// Constructs the testbed, selecting the random-fill distribution for each of the
  /// A (activation), B (filter), and C (source) tensors and the base RNG seed.
  /// All arguments default to uniform random data; the tested-problem counter starts at zero.
  TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
                               cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
                               cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
                               uint64_t seed_ = 2080)
      : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {}
|
| 95 |
+
|
| 96 |
+
/// Helper to initialize a tensor view
|
| 97 |
+
template <typename Element, typename Layout>
|
| 98 |
+
void initialize_tensor(cutlass::TensorView<Element, Layout> view,
|
| 99 |
+
cutlass::Distribution::Kind dist_kind,
|
| 100 |
+
uint64_t seed) {
|
| 101 |
+
if (dist_kind == cutlass::Distribution::Uniform) {
|
| 102 |
+
int scope;
|
| 103 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 104 |
+
|
| 105 |
+
if (bits <= 8) {
|
| 106 |
+
scope = 2;
|
| 107 |
+
} else if (bits == 16) {
|
| 108 |
+
if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
|
| 109 |
+
scope = 3;
|
| 110 |
+
} else {
|
| 111 |
+
scope = 5;
|
| 112 |
+
}
|
| 113 |
+
} else {
|
| 114 |
+
scope = 8;
|
| 115 |
+
}
|
| 116 |
+
cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0);
|
| 117 |
+
} else if (dist_kind == cutlass::Distribution::Identity) {
|
| 118 |
+
cutlass::reference::host::TensorFillIdentity(view);
|
| 119 |
+
|
| 120 |
+
} else if (dist_kind == cutlass::Distribution::Gaussian) {
|
| 121 |
+
cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
|
| 122 |
+
} else if (dist_kind == cutlass::Distribution::Sequential) {
|
| 123 |
+
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
|
| 124 |
+
} else {
|
| 125 |
+
}
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
  /// Allocates and initializes every tensor needed for one depthwise conv2d test:
  /// sizes each tensor per the implicit-GEMM mapping of `problem_size`, fills the
  /// host views using the configured distributions, and copies all of them to device.
  /// Note the default seed here (2019) differs from the testbed's member seed.
  void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
    // Extents are derived from the conv operator's implicit-GEMM view of the problem.
    tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
    tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
    tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
    tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));

    // Distinct seed multipliers decorrelate the fills of A, B, and C.
    // tensor_reordered_B uses the same seed as tensor_B so both start with identical data.
    initialize_tensor(tensor_A.host_view(), init_A, seed);
    initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
    initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17);
    initialize_tensor(tensor_C.host_view(), init_C, seed * 39);

    // Push all host data to the device, including the (still uninitialized) D tensors
    // so device-side buffers exist before the kernel and reference both write into them.
    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_reordered_B.sync_device();
    tensor_C.sync_device();
    tensor_D_computed.sync_device();
    tensor_D_reference.sync_device();
  }
|
| 148 |
+
|
| 149 |
+
bool sufficient(int smem_size) const {
|
| 150 |
+
//
|
| 151 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 152 |
+
//
|
| 153 |
+
|
| 154 |
+
cudaDeviceProp properties;
|
| 155 |
+
int device_idx;
|
| 156 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 157 |
+
|
| 158 |
+
if (result != cudaSuccess) {
|
| 159 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
result = cudaGetDeviceProperties(&properties, device_idx);
|
| 163 |
+
|
| 164 |
+
if (result != cudaSuccess) {
|
| 165 |
+
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
| 166 |
+
}
|
| 167 |
+
|
| 168 |
+
if (properties.sharedMemPerBlockOptin < static_cast<size_t>(smem_size)) {
|
| 169 |
+
return false;
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
return true;
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
/// Executes one test
|
| 176 |
+
bool run(cutlass::conv::Conv2dProblemSize const &problem_size,
|
| 177 |
+
cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
|
| 178 |
+
ElementCompute alpha = ElementCompute(1.5),
|
| 179 |
+
ElementCompute beta = ElementCompute(1)) {
|
| 180 |
+
// increment tested problem count run by the testbed
|
| 181 |
+
tested_problem_count++;
|
| 182 |
+
|
| 183 |
+
#if 0 // display conv2d problem size for debugging
|
| 184 |
+
std::cout << problem_size << std::endl
|
| 185 |
+
<< "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
|
| 186 |
+
<< "split_k_mode: "
|
| 187 |
+
<< ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)")
|
| 188 |
+
<< std::endl
|
| 189 |
+
<< std::endl;
|
| 190 |
+
#endif
|
| 191 |
+
|
| 192 |
+
initialize(problem_size);
|
| 193 |
+
|
| 194 |
+
// configure the operator
|
| 195 |
+
Conv2d conv2d_op;
|
| 196 |
+
|
| 197 |
+
typename Conv2d::Arguments conv2d_args(problem_size,
|
| 198 |
+
tensor_A.device_ref(),
|
| 199 |
+
tensor_B.device_ref(),
|
| 200 |
+
tensor_C.device_ref(),
|
| 201 |
+
tensor_D_computed.device_ref(),
|
| 202 |
+
{alpha, beta},
|
| 203 |
+
tensor_reordered_B.device_ref(),
|
| 204 |
+
split_k_mode);
|
| 205 |
+
|
| 206 |
+
// find workspace requirement for parallel split-k reduction
|
| 207 |
+
size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
|
| 208 |
+
|
| 209 |
+
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
| 210 |
+
|
| 211 |
+
cutlass::Status status = conv2d_op.can_implement(problem_size);
|
| 212 |
+
|
| 213 |
+
if (status != cutlass::Status::kSuccess) {
|
| 214 |
+
cudaError_t error = cudaGetLastError();
|
| 215 |
+
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
|
| 216 |
+
return true;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
status = conv2d_op.initialize(conv2d_args, workspace.get());
|
| 220 |
+
|
| 221 |
+
if (status != cutlass::Status::kSuccess) {
|
| 222 |
+
cudaError_t error = cudaGetLastError();
|
| 223 |
+
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
|
| 224 |
+
return true;
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
if (!sufficient(conv2d_op.get_smem_size())) {
|
| 228 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 229 |
+
std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
|
| 230 |
+
}
|
| 231 |
+
return true;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
// run conv2d operator
|
| 235 |
+
status = conv2d_op();
|
| 236 |
+
|
| 237 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 238 |
+
if (status != cutlass::Status::kSuccess) {
|
| 239 |
+
std::cerr << "Failed to run." << std::endl;
|
| 240 |
+
return false;
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
bool passed = false;
|
| 244 |
+
|
| 245 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 246 |
+
EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result);
|
| 247 |
+
|
| 248 |
+
tensor_D_computed.sync_host();
|
| 249 |
+
|
| 250 |
+
//
|
| 251 |
+
// Reference check - support caching results
|
| 252 |
+
//
|
| 253 |
+
|
| 254 |
+
CachedTestKey cached_test_key =
|
| 255 |
+
CreateCachedConv2dTestKey<ElementA,
|
| 256 |
+
LayoutA,
|
| 257 |
+
ElementB,
|
| 258 |
+
LayoutB,
|
| 259 |
+
ElementC,
|
| 260 |
+
LayoutC,
|
| 261 |
+
ElementAccumulator,
|
| 262 |
+
ElementCompute>(kConvolutionalOperator,
|
| 263 |
+
problem_size,
|
| 264 |
+
alpha,
|
| 265 |
+
beta,
|
| 266 |
+
tensor_A.host_view(),
|
| 267 |
+
tensor_B.host_view(),
|
| 268 |
+
tensor_C.host_view());
|
| 269 |
+
|
| 270 |
+
//
|
| 271 |
+
// Look for the cached key
|
| 272 |
+
//
|
| 273 |
+
|
| 274 |
+
bool cached_result_loaded = false;
|
| 275 |
+
CachedTestResult cached_test_result;
|
| 276 |
+
|
| 277 |
+
std::string conv2d_result_cache_name =
|
| 278 |
+
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
|
| 279 |
+
|
| 280 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 281 |
+
|
| 282 |
+
CachedTestResultListing cached_results(conv2d_result_cache_name);
|
| 283 |
+
|
| 284 |
+
auto cached = cached_results.find(cached_test_key);
|
| 285 |
+
|
| 286 |
+
cached_result_loaded = cached.first;
|
| 287 |
+
if (cached_result_loaded) {
|
| 288 |
+
cached_test_result = cached.second;
|
| 289 |
+
}
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
if (!cached_result_loaded) {
|
| 293 |
+
#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
|
| 294 |
+
|
| 295 |
+
cutlass::reference::device::Conv2d<ElementA,
|
| 296 |
+
LayoutA,
|
| 297 |
+
ElementB,
|
| 298 |
+
LayoutB,
|
| 299 |
+
ElementC,
|
| 300 |
+
LayoutC,
|
| 301 |
+
ElementCompute,
|
| 302 |
+
ElementAccumulator>(kConvolutionalOperator,
|
| 303 |
+
problem_size,
|
| 304 |
+
tensor_A.device_ref(),
|
| 305 |
+
tensor_B.device_ref(),
|
| 306 |
+
tensor_C.device_ref(),
|
| 307 |
+
tensor_D_reference.device_ref(),
|
| 308 |
+
alpha,
|
| 309 |
+
beta);
|
| 310 |
+
|
| 311 |
+
// sync host (copy device data to host) for dumping error output in case of mismatches
|
| 312 |
+
tensor_D_reference.sync_host();
|
| 313 |
+
|
| 314 |
+
#else
|
| 315 |
+
|
| 316 |
+
cutlass::reference::host::Conv2d<ElementA,
|
| 317 |
+
LayoutA,
|
| 318 |
+
ElementB,
|
| 319 |
+
LayoutB,
|
| 320 |
+
ElementC,
|
| 321 |
+
LayoutC,
|
| 322 |
+
ElementCompute,
|
| 323 |
+
ElementAccumulator>(kConvolutionalOperator,
|
| 324 |
+
problem_size,
|
| 325 |
+
tensor_A.host_ref(),
|
| 326 |
+
tensor_B.host_ref(),
|
| 327 |
+
tensor_C.host_ref(),
|
| 328 |
+
tensor_D_reference.host_ref(),
|
| 329 |
+
alpha,
|
| 330 |
+
beta);
|
| 331 |
+
|
| 332 |
+
#endif
|
| 333 |
+
|
| 334 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 335 |
+
|
| 336 |
+
cached_test_result.D = TensorHash(tensor_D_reference.host_view());
|
| 337 |
+
|
| 338 |
+
CachedTestResultListing cached_results(conv2d_result_cache_name);
|
| 339 |
+
|
| 340 |
+
cached_results.append(cached_test_key, cached_test_result);
|
| 341 |
+
cached_results.write(conv2d_result_cache_name);
|
| 342 |
+
}
|
| 343 |
+
} // if (!cached_result_loaded)
|
| 344 |
+
|
| 345 |
+
uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
|
| 346 |
+
|
| 347 |
+
if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
|
| 348 |
+
passed = (tensor_D_hash == cached_test_result.D);
|
| 349 |
+
|
| 350 |
+
EXPECT_EQ(tensor_D_hash, cached_test_result.D)
|
| 351 |
+
<< "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
|
| 352 |
+
}
|
| 353 |
+
else {
|
| 354 |
+
|
| 355 |
+
passed = cutlass::reference::host::TensorEquals(
|
| 356 |
+
tensor_D_computed.host_view(),
|
| 357 |
+
tensor_D_reference.host_view());
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
EXPECT_TRUE(passed);
|
| 361 |
+
|
| 362 |
+
std::stringstream ss_problem_size_text;
|
| 363 |
+
ss_problem_size_text << "nhwc_"
|
| 364 |
+
<< problem_size.N << "x"
|
| 365 |
+
<< problem_size.H << "x"
|
| 366 |
+
<< problem_size.W << "x"
|
| 367 |
+
<< problem_size.C
|
| 368 |
+
<< "_krsc_"
|
| 369 |
+
<< problem_size.K << "x"
|
| 370 |
+
<< problem_size.R << "x"
|
| 371 |
+
<< problem_size.S << "x"
|
| 372 |
+
<< problem_size.C
|
| 373 |
+
<< "_padding_"
|
| 374 |
+
<< problem_size.pad_h << "x"
|
| 375 |
+
<< problem_size.pad_w
|
| 376 |
+
<< "_stride_"
|
| 377 |
+
<< problem_size.stride_h << "x"
|
| 378 |
+
<< problem_size.stride_w
|
| 379 |
+
<< "_dilation_"
|
| 380 |
+
<< problem_size.dilation_h << "x"
|
| 381 |
+
<< problem_size.dilation_w << "_"
|
| 382 |
+
<< (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_");
|
| 383 |
+
|
| 384 |
+
if (!passed) {
|
| 385 |
+
std::stringstream fname;
|
| 386 |
+
|
| 387 |
+
fname << "error_Conv2d_DirectConv_device_"
|
| 388 |
+
<< (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
|
| 389 |
+
<< (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
|
| 390 |
+
(Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
|
| 391 |
+
<< ss_problem_size_text.str()
|
| 392 |
+
<< Conv2d::ThreadblockShape::kM << "x"
|
| 393 |
+
<< Conv2d::ThreadblockShape::kN << "x"
|
| 394 |
+
<< Conv2d::ThreadblockShape::kK << "_"
|
| 395 |
+
<< Conv2d::WarpShape::kM << "x"
|
| 396 |
+
<< Conv2d::WarpShape::kN << "x"
|
| 397 |
+
<< Conv2d::WarpShape::kK << ".txt";
|
| 398 |
+
|
| 399 |
+
std::cout << fname.str() << std::endl;
|
| 400 |
+
|
| 401 |
+
std::ofstream results(fname.str());
|
| 402 |
+
|
| 403 |
+
results << problem_size << std::endl;
|
| 404 |
+
|
| 405 |
+
results
|
| 406 |
+
<< "\nA:\n" << tensor_A.host_view() << "\n"
|
| 407 |
+
<< "\nB:\n" << tensor_B.host_view() << "\n"
|
| 408 |
+
<< "\nC:\n" << tensor_C.host_view() << "\n";
|
| 409 |
+
|
| 410 |
+
results << "\nD reference (hash: " << cached_test_result.D << ")\n";
|
| 411 |
+
|
| 412 |
+
if (!cached_result_loaded) {
|
| 413 |
+
results
|
| 414 |
+
<< tensor_D_reference.host_view() << "\n";
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
results
|
| 418 |
+
<< "\nD computed (hash: " << tensor_D_hash << ")\n"
|
| 419 |
+
<< tensor_D_computed.host_view() << "\n";
|
| 420 |
+
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
return passed;
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
};
|
| 427 |
+
|
| 428 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 429 |
+
|
| 430 |
+
template <typename DirectConv>
|
| 431 |
+
bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) {
|
| 432 |
+
bool passed = true;
|
| 433 |
+
|
| 434 |
+
//
|
| 435 |
+
// Testbed object
|
| 436 |
+
//
|
| 437 |
+
TestbedDepthwiseDirectConv2d<DirectConv> testbed;
|
| 438 |
+
|
| 439 |
+
// Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
|
| 440 |
+
for (auto conv_problem : problem_sizes) {
|
| 441 |
+
//
|
| 442 |
+
// Test
|
| 443 |
+
//
|
| 444 |
+
|
| 445 |
+
// test mode = xcross
|
| 446 |
+
passed = testbed.run(
|
| 447 |
+
conv_problem,
|
| 448 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 449 |
+
|
| 450 |
+
if (!passed) {
|
| 451 |
+
return false;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
// test mode = convolution
|
| 455 |
+
passed = testbed.run(
|
| 456 |
+
conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
|
| 457 |
+
cutlass::conv::SplitKMode::kSerial);
|
| 458 |
+
|
| 459 |
+
if (!passed) {
|
| 460 |
+
return false;
|
| 461 |
+
}
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
return true;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 468 |
+
|
| 469 |
+
} // namespace device
|
| 470 |
+
} // namespace conv
|
| 471 |
+
} // namespace test
|
| 472 |
+
|
| 473 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp
ADDED
|
@@ -0,0 +1,1385 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/conv/convnd_problem_shape.hpp"
|
| 37 |
+
#include <vector>
|
| 38 |
+
|
| 39 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 40 |
+
|
| 41 |
+
namespace test::conv::device {
|
| 42 |
+
|
| 43 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 44 |
+
|
| 45 |
+
template<int SpatialDim, cutlass::conv::Operator ConvOp, bool SupportStrides = (ConvOp != cutlass::conv::Operator::kDgrad)>
|
| 46 |
+
std::vector<cutlass::conv::ConvProblemShape<ConvOp, SpatialDim>>
|
| 47 |
+
inline
|
| 48 |
+
get_conv_problem_vector();
|
| 49 |
+
|
| 50 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 51 |
+
// Fprop
|
| 52 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 53 |
+
|
| 54 |
+
// Specialization for 1D fprop problems
|
| 55 |
+
template<>
|
| 56 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 1>> inline
|
| 57 |
+
get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() {
|
| 58 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 1>;
|
| 59 |
+
std::vector<ProblemShape> problem_shapes;
|
| 60 |
+
problem_shapes.push_back({
|
| 61 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 62 |
+
{1, 8, 64}, // nwc
|
| 63 |
+
{64, 1, 64}, // ksc
|
| 64 |
+
{0}, // padding lower (pad_w)
|
| 65 |
+
{0}, // padding upper (pad_w)
|
| 66 |
+
{1}, // stride (stride_w)
|
| 67 |
+
{1}, // dilation (dilation_w)
|
| 68 |
+
1 // group
|
| 69 |
+
});
|
| 70 |
+
// non-packed input strides.
|
| 71 |
+
problem_shapes.push_back({
|
| 72 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 73 |
+
{1, 8, 64}, // nwc
|
| 74 |
+
{800, 80, 1}, // stride (nwc)
|
| 75 |
+
{64, 1, 64}, // ksc
|
| 76 |
+
{64, 64, 1}, // stride (ksc)
|
| 77 |
+
{0}, // padding lower (pad_w)
|
| 78 |
+
{0}, // padding upper (pad_w)
|
| 79 |
+
{1}, // stride (stride_w)
|
| 80 |
+
{1}, // dilation (dilation_w)
|
| 81 |
+
1 // group
|
| 82 |
+
});
|
| 83 |
+
// non-packed output strides.
|
| 84 |
+
problem_shapes.push_back({
|
| 85 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 86 |
+
{1, 8, 64}, // nwc
|
| 87 |
+
{512, 64, 1}, // stride (nwc)
|
| 88 |
+
{64, 1, 64}, // ksc
|
| 89 |
+
{64, 64, 1}, // stride (ksc)
|
| 90 |
+
{800, 80, 1}, // stride (nqk)
|
| 91 |
+
{0}, // padding lower (pad_w)
|
| 92 |
+
{0}, // padding upper (pad_w)
|
| 93 |
+
{1}, // stride (stride_w)
|
| 94 |
+
{1}, // dilation (dilation_w)
|
| 95 |
+
1 // group
|
| 96 |
+
});
|
| 97 |
+
// Filter-K = 16 for predication
|
| 98 |
+
problem_shapes.push_back({
|
| 99 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 100 |
+
{1, 8, 64},
|
| 101 |
+
{16,1, 64},
|
| 102 |
+
{0},
|
| 103 |
+
{0},
|
| 104 |
+
{1},
|
| 105 |
+
{1},
|
| 106 |
+
1
|
| 107 |
+
});
|
| 108 |
+
// N = 2 and K = 128 for a larger grid
|
| 109 |
+
problem_shapes.push_back({
|
| 110 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 111 |
+
{2, 8, 64},
|
| 112 |
+
{96, 1, 64},
|
| 113 |
+
{0},
|
| 114 |
+
{0},
|
| 115 |
+
{1},
|
| 116 |
+
{1},
|
| 117 |
+
1
|
| 118 |
+
});
|
| 119 |
+
// N = 7 and K = 256 for a even larger grid
|
| 120 |
+
problem_shapes.push_back({
|
| 121 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 122 |
+
{7, 8, 64},
|
| 123 |
+
{256, 1, 64},
|
| 124 |
+
{0},
|
| 125 |
+
{0},
|
| 126 |
+
{1},
|
| 127 |
+
{1},
|
| 128 |
+
1
|
| 129 |
+
});
|
| 130 |
+
// 3 filter, no padding
|
| 131 |
+
problem_shapes.push_back({
|
| 132 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 133 |
+
{2, 8, 64},
|
| 134 |
+
{256, 3, 64},
|
| 135 |
+
{0},
|
| 136 |
+
{0},
|
| 137 |
+
{1},
|
| 138 |
+
{1},
|
| 139 |
+
1
|
| 140 |
+
});
|
| 141 |
+
// 3 filter, symmetric padding with c % cta_k !=0
|
| 142 |
+
problem_shapes.push_back({
|
| 143 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 144 |
+
{2, 8, 32},
|
| 145 |
+
{256, 3, 32},
|
| 146 |
+
{1},
|
| 147 |
+
{1},
|
| 148 |
+
{1},
|
| 149 |
+
{1},
|
| 150 |
+
1
|
| 151 |
+
});
|
| 152 |
+
// 4 filter, asymmetric padding
|
| 153 |
+
problem_shapes.push_back({
|
| 154 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 155 |
+
{2, 8, 64},
|
| 156 |
+
{256, 4, 64},
|
| 157 |
+
{0},
|
| 158 |
+
{1},
|
| 159 |
+
{1},
|
| 160 |
+
{1},
|
| 161 |
+
1
|
| 162 |
+
});
|
| 163 |
+
// 3 filter, asymmetric padding and tstride of 2
|
| 164 |
+
problem_shapes.push_back({
|
| 165 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 166 |
+
{2, 8, 64},
|
| 167 |
+
{256, 3, 64},
|
| 168 |
+
{0},
|
| 169 |
+
{1},
|
| 170 |
+
{2},
|
| 171 |
+
{1},
|
| 172 |
+
1
|
| 173 |
+
});
|
| 174 |
+
// 3 filter, asymmetric padding and dilation of 2
|
| 175 |
+
problem_shapes.push_back({
|
| 176 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 177 |
+
{2, 8, 64},
|
| 178 |
+
{256, 3, 64},
|
| 179 |
+
{0},
|
| 180 |
+
{1},
|
| 181 |
+
{1},
|
| 182 |
+
{2},
|
| 183 |
+
1
|
| 184 |
+
});
|
| 185 |
+
return problem_shapes;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// Specialization for 2D fprop problems
|
| 189 |
+
template<>
|
| 190 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 2>> inline
|
| 191 |
+
get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() {
|
| 192 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 2>;
|
| 193 |
+
std::vector<ProblemShape> problem_shapes;
|
| 194 |
+
problem_shapes.push_back({
|
| 195 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 196 |
+
{1, 8, 8, 64}, // nhwc
|
| 197 |
+
{64, 1, 1, 64}, // krsc
|
| 198 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 199 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 200 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 201 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 202 |
+
1 // group
|
| 203 |
+
});
|
| 204 |
+
// non-packed input strides.
|
| 205 |
+
problem_shapes.push_back({
|
| 206 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 207 |
+
{1, 8, 8, 64}, // nhwc
|
| 208 |
+
{8000, 800, 80, 1}, // stride (nhwc)
|
| 209 |
+
{64, 1, 1, 64}, // krsc
|
| 210 |
+
{64, 64, 64, 1}, // stride (krsc)
|
| 211 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 212 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 213 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 214 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 215 |
+
1 // group
|
| 216 |
+
});
|
| 217 |
+
// non-packed output strides.
|
| 218 |
+
problem_shapes.push_back({
|
| 219 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 220 |
+
{1, 8, 8, 64}, // nhwc
|
| 221 |
+
{4096, 512, 64, 1}, // stride (nhwc)
|
| 222 |
+
{64, 1, 1, 64}, // krsc
|
| 223 |
+
{64, 64, 64, 1}, // stride (krsc)
|
| 224 |
+
{8000, 800, 80, 1}, // stride (npqk)
|
| 225 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 226 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 227 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 228 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 229 |
+
1 // group
|
| 230 |
+
});
|
| 231 |
+
// Filter-K = 16 for predication
|
| 232 |
+
problem_shapes.push_back({
|
| 233 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 234 |
+
{1, 8, 8, 64},
|
| 235 |
+
{16, 1, 1, 64},
|
| 236 |
+
{0, 0},
|
| 237 |
+
{0, 0},
|
| 238 |
+
{1, 1},
|
| 239 |
+
{1, 1},
|
| 240 |
+
1
|
| 241 |
+
});
|
| 242 |
+
// N = 2 and K = 128 for a larger grid
|
| 243 |
+
problem_shapes.push_back({
|
| 244 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 245 |
+
{2, 8, 8, 64},
|
| 246 |
+
{96, 1, 1, 64},
|
| 247 |
+
{0, 0},
|
| 248 |
+
{0, 0},
|
| 249 |
+
{1, 1},
|
| 250 |
+
{1, 1},
|
| 251 |
+
1
|
| 252 |
+
});
|
| 253 |
+
// N = 7 and K = 256 for a even larger grid
|
| 254 |
+
problem_shapes.push_back({
|
| 255 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 256 |
+
{7, 8, 8, 64},
|
| 257 |
+
{256, 1, 1, 64},
|
| 258 |
+
{0, 0},
|
| 259 |
+
{0, 0},
|
| 260 |
+
{1, 1},
|
| 261 |
+
{1, 1},
|
| 262 |
+
1
|
| 263 |
+
});
|
| 264 |
+
// 3x3 filter, no padding
|
| 265 |
+
problem_shapes.push_back({
|
| 266 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 267 |
+
{2, 8, 8, 64},
|
| 268 |
+
{256, 3, 3, 64},
|
| 269 |
+
{0, 0},
|
| 270 |
+
{0, 0},
|
| 271 |
+
{1, 1},
|
| 272 |
+
{1, 1},
|
| 273 |
+
1
|
| 274 |
+
});
|
| 275 |
+
// 3x3 filter, symmetric padding with c % cta_k !=0
|
| 276 |
+
problem_shapes.push_back({
|
| 277 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 278 |
+
{2, 8, 8, 32},
|
| 279 |
+
{256, 3, 3, 32},
|
| 280 |
+
{1, 1},
|
| 281 |
+
{1, 1},
|
| 282 |
+
{1, 1},
|
| 283 |
+
{1, 1},
|
| 284 |
+
1
|
| 285 |
+
});
|
| 286 |
+
// 2x5 filter, asymmetric padding 1,2/1,2
|
| 287 |
+
problem_shapes.push_back({
|
| 288 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 289 |
+
{2, 8, 8, 64},
|
| 290 |
+
{256, 2, 5, 64},
|
| 291 |
+
{1, 1},
|
| 292 |
+
{2, 2},
|
| 293 |
+
{1, 1},
|
| 294 |
+
{1, 1},
|
| 295 |
+
1
|
| 296 |
+
});
|
| 297 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ stride
|
| 298 |
+
problem_shapes.push_back({
|
| 299 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 300 |
+
{2, 7, 7, 64},
|
| 301 |
+
{256, 2, 5, 64},
|
| 302 |
+
{1, 1},
|
| 303 |
+
{0, 0},
|
| 304 |
+
{2, 3},
|
| 305 |
+
{1, 1},
|
| 306 |
+
1
|
| 307 |
+
});
|
| 308 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
|
| 309 |
+
problem_shapes.push_back({
|
| 310 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 311 |
+
{2, 16, 16, 64},
|
| 312 |
+
{256, 2, 5, 64},
|
| 313 |
+
{1, 1},
|
| 314 |
+
{0, 0},
|
| 315 |
+
{1, 1},
|
| 316 |
+
{2, 3},
|
| 317 |
+
1
|
| 318 |
+
});
|
| 319 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation
|
| 320 |
+
problem_shapes.push_back({
|
| 321 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 322 |
+
{2, 16, 15, 64},
|
| 323 |
+
{256, 2, 5, 64},
|
| 324 |
+
{1, 1},
|
| 325 |
+
{0, 0},
|
| 326 |
+
{2, 3},
|
| 327 |
+
{2, 3},
|
| 328 |
+
1
|
| 329 |
+
});
|
| 330 |
+
return problem_shapes;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
// Specialization for 3D fprop problems
|
| 334 |
+
template<>
|
| 335 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 3>> inline
|
| 336 |
+
get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() {
|
| 337 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 3>;
|
| 338 |
+
std::vector<ProblemShape> problem_shapes;
|
| 339 |
+
problem_shapes.push_back({
|
| 340 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 341 |
+
{1, 1, 8, 8, 64}, // ndhwc
|
| 342 |
+
{64, 1, 1, 1, 64}, // ktrsc
|
| 343 |
+
{0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
|
| 344 |
+
{0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
|
| 345 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 346 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 347 |
+
1 // group
|
| 348 |
+
});
|
| 349 |
+
// non-packed input output strides.
|
| 350 |
+
problem_shapes.push_back({
|
| 351 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 352 |
+
{1, 1, 8, 8, 64}, // ndhwc
|
| 353 |
+
{8000, 8000, 800, 80, 1}, // stride (ndhwc)
|
| 354 |
+
{64, 1, 1, 1, 64}, // ktrsc
|
| 355 |
+
{64, 64, 64, 64, 1}, // stride (ktrsc)
|
| 356 |
+
{8000, 8000, 800, 80, 1}, // stride (nzpqk)
|
| 357 |
+
{0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
|
| 358 |
+
{0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
|
| 359 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 360 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 361 |
+
1 // group
|
| 362 |
+
});
|
| 363 |
+
// Filter-K = 16 for predication
|
| 364 |
+
problem_shapes.push_back({
|
| 365 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 366 |
+
{1, 1, 8, 8, 64},
|
| 367 |
+
{16, 1, 1, 1, 64},
|
| 368 |
+
{0, 0, 0},
|
| 369 |
+
{0, 0, 0},
|
| 370 |
+
{1, 1, 1},
|
| 371 |
+
{1, 1, 1},
|
| 372 |
+
1
|
| 373 |
+
});
|
| 374 |
+
// N = 7 and K = 256 for a larger grid
|
| 375 |
+
problem_shapes.push_back({
|
| 376 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 377 |
+
{2, 1, 8, 8, 64},
|
| 378 |
+
{96, 1, 1, 1, 64},
|
| 379 |
+
{0, 0, 0},
|
| 380 |
+
{0, 0, 0},
|
| 381 |
+
{1, 1, 1},
|
| 382 |
+
{1, 1, 1},
|
| 383 |
+
1
|
| 384 |
+
});
|
| 385 |
+
// Filter 3x3x3 + no padding
|
| 386 |
+
problem_shapes.push_back({
|
| 387 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 388 |
+
{2, 3, 5, 8, 64},
|
| 389 |
+
{96, 3, 3, 3, 64},
|
| 390 |
+
{0, 0, 0},
|
| 391 |
+
{0, 0, 0},
|
| 392 |
+
{1, 1, 1},
|
| 393 |
+
{1, 1, 1},
|
| 394 |
+
1
|
| 395 |
+
});
|
| 396 |
+
// Filter 3x3x3 + symmetric padding with c % cta_k !=0
|
| 397 |
+
problem_shapes.push_back({
|
| 398 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 399 |
+
{2, 3, 5, 8, 32},
|
| 400 |
+
{96, 3, 3, 3, 32},
|
| 401 |
+
{1, 1, 1},
|
| 402 |
+
{1, 1, 1},
|
| 403 |
+
{1, 1, 1},
|
| 404 |
+
{1, 1, 1},
|
| 405 |
+
1
|
| 406 |
+
});
|
| 407 |
+
// Filter 3x4x5 + symmetric padding 111
|
| 408 |
+
problem_shapes.push_back({
|
| 409 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 410 |
+
{2, 3, 5, 8, 64},
|
| 411 |
+
{96, 3, 4, 5, 64},
|
| 412 |
+
{1, 1, 1},
|
| 413 |
+
{1, 1, 1},
|
| 414 |
+
{1, 1, 1},
|
| 415 |
+
{1, 1, 1},
|
| 416 |
+
1
|
| 417 |
+
});
|
| 418 |
+
// Filter 3x4x5 + asymmetric padding 102/010
|
| 419 |
+
problem_shapes.push_back({
|
| 420 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 421 |
+
{2, 3, 5, 8, 64},
|
| 422 |
+
{96, 3, 4, 5, 64},
|
| 423 |
+
{1, 0, 1},
|
| 424 |
+
{0, 2, 0},
|
| 425 |
+
{1, 1, 1},
|
| 426 |
+
{1, 1, 1},
|
| 427 |
+
1
|
| 428 |
+
});
|
| 429 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ stride
|
| 430 |
+
problem_shapes.push_back({
|
| 431 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 432 |
+
{2, 16, 10, 16, 64},
|
| 433 |
+
{96, 3, 4, 5, 64},
|
| 434 |
+
{1, 0, 1},
|
| 435 |
+
{0, 2, 0},
|
| 436 |
+
{2, 2, 3},
|
| 437 |
+
{1, 1, 1},
|
| 438 |
+
1
|
| 439 |
+
});
|
| 440 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
|
| 441 |
+
problem_shapes.push_back({
|
| 442 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 443 |
+
{2, 16, 10, 16, 64},
|
| 444 |
+
{96, 3, 4, 5, 64},
|
| 445 |
+
{1, 0, 1},
|
| 446 |
+
{0, 2, 0},
|
| 447 |
+
{1, 1, 1},
|
| 448 |
+
{2, 2, 3},
|
| 449 |
+
1
|
| 450 |
+
});
|
| 451 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation
|
| 452 |
+
problem_shapes.push_back({
|
| 453 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 454 |
+
{2, 16, 10, 16, 64},
|
| 455 |
+
{96, 3, 4, 5, 64},
|
| 456 |
+
{1, 0, 1},
|
| 457 |
+
{0, 2, 0},
|
| 458 |
+
{2, 2, 3},
|
| 459 |
+
{2, 2, 3},
|
| 460 |
+
1
|
| 461 |
+
});
|
| 462 |
+
return problem_shapes;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 467 |
+
// Wgrad
|
| 468 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 469 |
+
|
| 470 |
+
// Specialization for 1D wgrad problems
|
| 471 |
+
template<>
|
| 472 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 1>> inline
|
| 473 |
+
get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() {
|
| 474 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 1>;
|
| 475 |
+
std::vector<ProblemShape> problem_shapes;
|
| 476 |
+
problem_shapes.push_back({
|
| 477 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 478 |
+
{1, 8, 64}, // nwc
|
| 479 |
+
{64, 1, 64}, // ksc
|
| 480 |
+
{0}, // padding lower (pad_w)
|
| 481 |
+
{0}, // padding upper (pad_w)
|
| 482 |
+
{1}, // stride (stride_w)
|
| 483 |
+
{1}, // dilation (dilation_w)
|
| 484 |
+
1 // group
|
| 485 |
+
});
|
| 486 |
+
// Filter-K = 16 for predication
|
| 487 |
+
problem_shapes.push_back({
|
| 488 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 489 |
+
{1, 8, 64},
|
| 490 |
+
{16,1, 64},
|
| 491 |
+
{0},
|
| 492 |
+
{0},
|
| 493 |
+
{1},
|
| 494 |
+
{1},
|
| 495 |
+
1
|
| 496 |
+
});
|
| 497 |
+
// N = 2 and K = 128 for a larger grid
|
| 498 |
+
problem_shapes.push_back({
|
| 499 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 500 |
+
{2, 8, 64},
|
| 501 |
+
{96, 1, 64},
|
| 502 |
+
{0},
|
| 503 |
+
{0},
|
| 504 |
+
{1},
|
| 505 |
+
{1},
|
| 506 |
+
1
|
| 507 |
+
});
|
| 508 |
+
// N = 7 and K = 256 for a even larger grid
|
| 509 |
+
problem_shapes.push_back({
|
| 510 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 511 |
+
{7, 8, 64},
|
| 512 |
+
{256, 1, 64},
|
| 513 |
+
{0},
|
| 514 |
+
{0},
|
| 515 |
+
{1},
|
| 516 |
+
{1},
|
| 517 |
+
1
|
| 518 |
+
});
|
| 519 |
+
// 3 filter, no padding
|
| 520 |
+
problem_shapes.push_back({
|
| 521 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 522 |
+
{2, 8, 32},
|
| 523 |
+
{256, 3, 32},
|
| 524 |
+
{0},
|
| 525 |
+
{0},
|
| 526 |
+
{1},
|
| 527 |
+
{1},
|
| 528 |
+
1
|
| 529 |
+
});
|
| 530 |
+
// 3 filter, symmetric padding
|
| 531 |
+
problem_shapes.push_back({
|
| 532 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 533 |
+
{2, 8, 32},
|
| 534 |
+
{256, 3, 32},
|
| 535 |
+
{1},
|
| 536 |
+
{1},
|
| 537 |
+
{1},
|
| 538 |
+
{1},
|
| 539 |
+
1
|
| 540 |
+
});
|
| 541 |
+
// 4 filter, asymmetric padding
|
| 542 |
+
problem_shapes.push_back({
|
| 543 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 544 |
+
{2, 8, 32},
|
| 545 |
+
{256, 4, 32},
|
| 546 |
+
{0},
|
| 547 |
+
{1},
|
| 548 |
+
{1},
|
| 549 |
+
{1},
|
| 550 |
+
1
|
| 551 |
+
});
|
| 552 |
+
// 3 filter, asymmetric padding and tstride of 2
|
| 553 |
+
problem_shapes.push_back({
|
| 554 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 555 |
+
{2, 8, 32},
|
| 556 |
+
{256, 3, 32},
|
| 557 |
+
{0},
|
| 558 |
+
{1},
|
| 559 |
+
{2},
|
| 560 |
+
{1},
|
| 561 |
+
1
|
| 562 |
+
});
|
| 563 |
+
// 3 filter, asymmetric padding and dilation of 2
|
| 564 |
+
problem_shapes.push_back({
|
| 565 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 566 |
+
{2, 8, 32},
|
| 567 |
+
{256, 3, 32},
|
| 568 |
+
{0},
|
| 569 |
+
{1},
|
| 570 |
+
{1},
|
| 571 |
+
{2},
|
| 572 |
+
1
|
| 573 |
+
});
|
| 574 |
+
// To test streamk, equals to gemm-MxNxK size 128x640x2048
|
| 575 |
+
problem_shapes.push_back({
|
| 576 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 577 |
+
{2, 1024, 128},
|
| 578 |
+
{640, 1, 128},
|
| 579 |
+
{0},
|
| 580 |
+
{0},
|
| 581 |
+
{1},
|
| 582 |
+
{1},
|
| 583 |
+
1
|
| 584 |
+
});
|
| 585 |
+
// To test streamk, equals to gemm-MxNxK size 128x640x2080
|
| 586 |
+
problem_shapes.push_back({
|
| 587 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 588 |
+
{2, 1040, 128},
|
| 589 |
+
{640, 1, 128},
|
| 590 |
+
{0},
|
| 591 |
+
{0},
|
| 592 |
+
{1},
|
| 593 |
+
{1},
|
| 594 |
+
1
|
| 595 |
+
});
|
| 596 |
+
return problem_shapes;
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
// Specialization for 2D wgrad problems
|
| 600 |
+
template<>
|
| 601 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 2>> inline
|
| 602 |
+
get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() {
|
| 603 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 2>;
|
| 604 |
+
std::vector<ProblemShape> problem_shapes;
|
| 605 |
+
problem_shapes.push_back({
|
| 606 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 607 |
+
{1, 8, 8, 64}, // nhwc
|
| 608 |
+
{64, 1, 1, 64}, // krsc
|
| 609 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 610 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 611 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 612 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 613 |
+
1 // group
|
| 614 |
+
});
|
| 615 |
+
// Filter-K = 16 for predication
|
| 616 |
+
problem_shapes.push_back({
|
| 617 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 618 |
+
{1, 8, 8, 64},
|
| 619 |
+
{16, 1, 1, 64},
|
| 620 |
+
{0, 0},
|
| 621 |
+
{0, 0},
|
| 622 |
+
{1, 1},
|
| 623 |
+
{1, 1},
|
| 624 |
+
1
|
| 625 |
+
});
|
| 626 |
+
// N = 2 and K = 128 for a larger grid
|
| 627 |
+
problem_shapes.push_back({
|
| 628 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 629 |
+
{2, 8, 8, 64},
|
| 630 |
+
{96, 1, 1, 64},
|
| 631 |
+
{0, 0},
|
| 632 |
+
{0, 0},
|
| 633 |
+
{1, 1},
|
| 634 |
+
{1, 1},
|
| 635 |
+
1
|
| 636 |
+
});
|
| 637 |
+
// N = 7 and K = 256 for a even larger grid
|
| 638 |
+
problem_shapes.push_back({
|
| 639 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 640 |
+
{7, 8, 8, 64},
|
| 641 |
+
{256, 1, 1, 64},
|
| 642 |
+
{0, 0},
|
| 643 |
+
{0, 0},
|
| 644 |
+
{1, 1},
|
| 645 |
+
{1, 1},
|
| 646 |
+
1
|
| 647 |
+
});
|
| 648 |
+
// 3x3 filter, no padding
|
| 649 |
+
problem_shapes.push_back({
|
| 650 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 651 |
+
{2, 8, 8, 32},
|
| 652 |
+
{256, 3, 3, 32},
|
| 653 |
+
{0, 0},
|
| 654 |
+
{0, 0},
|
| 655 |
+
{1, 1},
|
| 656 |
+
{1, 1},
|
| 657 |
+
1
|
| 658 |
+
});
|
| 659 |
+
// 3x3 filter, symmetric padding
|
| 660 |
+
problem_shapes.push_back({
|
| 661 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 662 |
+
{2, 8, 8, 32},
|
| 663 |
+
{256, 3, 3, 32},
|
| 664 |
+
{1, 1},
|
| 665 |
+
{1, 1},
|
| 666 |
+
{1, 1},
|
| 667 |
+
{1, 1},
|
| 668 |
+
1
|
| 669 |
+
});
|
| 670 |
+
// 2x5 filter, asymmetric padding 1,0/1,0
|
| 671 |
+
problem_shapes.push_back({
|
| 672 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 673 |
+
{2, 8, 8, 32},
|
| 674 |
+
{256, 2, 5, 32},
|
| 675 |
+
{1, 1},
|
| 676 |
+
{0, 0},
|
| 677 |
+
{1, 1},
|
| 678 |
+
{1, 1},
|
| 679 |
+
1
|
| 680 |
+
});
|
| 681 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ stride
|
| 682 |
+
problem_shapes.push_back({
|
| 683 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 684 |
+
{2, 15, 16, 32},
|
| 685 |
+
{256, 2, 5, 32},
|
| 686 |
+
{1, 1},
|
| 687 |
+
{0, 0},
|
| 688 |
+
{2, 3},
|
| 689 |
+
{1, 1},
|
| 690 |
+
1
|
| 691 |
+
});
|
| 692 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
|
| 693 |
+
problem_shapes.push_back({
|
| 694 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 695 |
+
{2, 16, 16, 32},
|
| 696 |
+
{256, 2, 5, 32},
|
| 697 |
+
{1, 1},
|
| 698 |
+
{0, 0},
|
| 699 |
+
{1, 1},
|
| 700 |
+
{2, 3},
|
| 701 |
+
1
|
| 702 |
+
});
|
| 703 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation
|
| 704 |
+
problem_shapes.push_back({
|
| 705 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 706 |
+
{2, 16, 15, 32},
|
| 707 |
+
{256, 2, 5, 32},
|
| 708 |
+
{1, 1},
|
| 709 |
+
{0, 0},
|
| 710 |
+
{2, 3},
|
| 711 |
+
{2, 3},
|
| 712 |
+
1
|
| 713 |
+
});
|
| 714 |
+
// To test streamk, equals to gemm-MxNxK size 128x640x2048
|
| 715 |
+
problem_shapes.push_back({
|
| 716 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 717 |
+
{2, 64, 16, 128},
|
| 718 |
+
{640, 1, 1, 128},
|
| 719 |
+
{0, 0},
|
| 720 |
+
{0, 0},
|
| 721 |
+
{1, 1},
|
| 722 |
+
{1, 1},
|
| 723 |
+
1
|
| 724 |
+
});
|
| 725 |
+
// To test streamk, equals to gemm-MxNxK size 128x640x2080
|
| 726 |
+
problem_shapes.push_back({
|
| 727 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 728 |
+
{2, 65, 16, 128},
|
| 729 |
+
{640, 1, 1, 128},
|
| 730 |
+
{0, 0},
|
| 731 |
+
{0, 0},
|
| 732 |
+
{1, 1},
|
| 733 |
+
{1, 1},
|
| 734 |
+
1
|
| 735 |
+
});
|
| 736 |
+
return problem_shapes;
|
| 737 |
+
}
|
| 738 |
+
|
| 739 |
+
// Specialization for 3D wgrad problems
|
| 740 |
+
template<>
|
| 741 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>> inline
|
| 742 |
+
get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() {
|
| 743 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>;
|
| 744 |
+
std::vector<ProblemShape> problem_shapes;
|
| 745 |
+
problem_shapes.push_back({
|
| 746 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 747 |
+
{2, 1, 8, 8, 64}, // ndhwc
|
| 748 |
+
{64, 1, 1, 1, 64}, // ktrsc
|
| 749 |
+
{0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
|
| 750 |
+
{0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
|
| 751 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 752 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 753 |
+
1 // group
|
| 754 |
+
});
|
| 755 |
+
// Filter 3x3x3 + no padding
|
| 756 |
+
problem_shapes.push_back({
|
| 757 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 758 |
+
{2, 3, 5, 8, 32},
|
| 759 |
+
{96, 3, 3, 3, 32},
|
| 760 |
+
{0, 0, 0},
|
| 761 |
+
{0, 0, 0},
|
| 762 |
+
{1, 1, 1},
|
| 763 |
+
{1, 1, 1},
|
| 764 |
+
1
|
| 765 |
+
});
|
| 766 |
+
// Filter 3x4x5 + asymmetric padding 102/010
|
| 767 |
+
problem_shapes.push_back({
|
| 768 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 769 |
+
{2, 3, 5, 8, 32},
|
| 770 |
+
{96, 3, 4, 5, 32},
|
| 771 |
+
{1, 0, 1},
|
| 772 |
+
{0, 2, 0},
|
| 773 |
+
{1, 1, 1},
|
| 774 |
+
{1, 1, 1},
|
| 775 |
+
1
|
| 776 |
+
});
|
| 777 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ stride
|
| 778 |
+
problem_shapes.push_back({
|
| 779 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 780 |
+
{2, 16, 10, 16, 32},
|
| 781 |
+
{96, 3, 4, 5, 32},
|
| 782 |
+
{1, 0, 1},
|
| 783 |
+
{0, 2, 0},
|
| 784 |
+
{2, 2, 3},
|
| 785 |
+
{1, 1, 1},
|
| 786 |
+
1
|
| 787 |
+
});
|
| 788 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
|
| 789 |
+
problem_shapes.push_back({
|
| 790 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 791 |
+
{2, 16, 10, 16, 32},
|
| 792 |
+
{96, 3, 4, 5, 32},
|
| 793 |
+
{1, 0, 1},
|
| 794 |
+
{0, 2, 0},
|
| 795 |
+
{1, 1, 1},
|
| 796 |
+
{2, 2, 3},
|
| 797 |
+
1
|
| 798 |
+
});
|
| 799 |
+
// To test streamk, equals to gemm-MxNxK size 128x640x2048
|
| 800 |
+
problem_shapes.push_back({
|
| 801 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 802 |
+
{2, 1, 64, 16, 128},
|
| 803 |
+
{640, 1, 1, 1, 128},
|
| 804 |
+
{0, 0, 0},
|
| 805 |
+
{0, 0, 0},
|
| 806 |
+
{1, 1, 1},
|
| 807 |
+
{1, 1, 1},
|
| 808 |
+
1
|
| 809 |
+
});
|
| 810 |
+
// To test streamk, equals to gemm-MxNxK size 128x640x2080
|
| 811 |
+
problem_shapes.push_back({
|
| 812 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 813 |
+
{2, 1, 65, 16, 128},
|
| 814 |
+
{640, 1, 1, 1, 128},
|
| 815 |
+
{0, 0, 0},
|
| 816 |
+
{0, 0, 0},
|
| 817 |
+
{1, 1, 1},
|
| 818 |
+
{1, 1, 1},
|
| 819 |
+
1
|
| 820 |
+
});
|
| 821 |
+
return problem_shapes;
|
| 822 |
+
}
|
| 823 |
+
|
| 824 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 825 |
+
// Grouped Wgrad
|
| 826 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 827 |
+
|
| 828 |
+
// Get problem size vectors for group conv problems
|
| 829 |
+
template<int SpatialDim, cutlass::conv::Operator ConvOp>
|
| 830 |
+
std::vector<cutlass::conv::ConvProblemShape<ConvOp, SpatialDim>>
|
| 831 |
+
inline
|
| 832 |
+
get_grouped_conv_problem_vector(int GroupsPerTile);
|
| 833 |
+
|
| 834 |
+
// Specialization for 3D wgrad problems
|
| 835 |
+
template<>
|
| 836 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>> inline
|
| 837 |
+
get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) {
|
| 838 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>;
|
| 839 |
+
std::vector<ProblemShape> problem_shapes;
|
| 840 |
+
|
| 841 |
+
if (GroupsPerTile == 1) {
|
| 842 |
+
// channel_per_group == 64
|
| 843 |
+
problem_shapes.push_back({
|
| 844 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 845 |
+
{1, 1, 16, 16, 2048}, // ndhwc
|
| 846 |
+
{2048, 1, 3, 3, 64}, // ktrsc
|
| 847 |
+
{0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
|
| 848 |
+
{0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
|
| 849 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 850 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 851 |
+
32 // groups
|
| 852 |
+
});
|
| 853 |
+
}
|
| 854 |
+
else if (GroupsPerTile == 2) {
|
| 855 |
+
// channel_per_group == 32
|
| 856 |
+
problem_shapes.push_back({
|
| 857 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 858 |
+
{1, 1, 16, 16, 1024}, // ndhwc
|
| 859 |
+
{1024, 1, 3, 3, 32}, // ktrsc
|
| 860 |
+
{0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
|
| 861 |
+
{0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
|
| 862 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 863 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 864 |
+
32 // groups
|
| 865 |
+
});
|
| 866 |
+
}
|
| 867 |
+
else if (GroupsPerTile == 4) {
|
| 868 |
+
// channel_per_group == 16
|
| 869 |
+
problem_shapes.push_back({
|
| 870 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 871 |
+
{1, 1, 16, 16, 512}, // ndhwc
|
| 872 |
+
{512, 1, 3, 3, 16}, // ktrsc
|
| 873 |
+
{0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
|
| 874 |
+
{0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
|
| 875 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 876 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 877 |
+
32 // groups
|
| 878 |
+
});
|
| 879 |
+
}
|
| 880 |
+
else if (GroupsPerTile == 8) {
|
| 881 |
+
// channel_per_group == 8
|
| 882 |
+
problem_shapes.push_back({
|
| 883 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 884 |
+
{1, 1, 16, 16, 256}, // ndhwc
|
| 885 |
+
{256, 1, 3, 3, 8}, // ktrsc
|
| 886 |
+
{0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
|
| 887 |
+
{0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
|
| 888 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 889 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 890 |
+
32 // groups
|
| 891 |
+
});
|
| 892 |
+
}
|
| 893 |
+
return problem_shapes;
|
| 894 |
+
}
|
| 895 |
+
|
| 896 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 897 |
+
// Unit Stride Dgrad
|
| 898 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 899 |
+
|
| 900 |
+
// Specialization for 1D dgrad problems
|
| 901 |
+
template<>
|
| 902 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>> inline
|
| 903 |
+
get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() {
|
| 904 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>;
|
| 905 |
+
std::vector<ProblemShape> problem_shapes;
|
| 906 |
+
problem_shapes.push_back({
|
| 907 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 908 |
+
{1, 8, 64}, // nqk
|
| 909 |
+
{64, 1, 64}, // ksc
|
| 910 |
+
{0}, // padding lower (pad_w)
|
| 911 |
+
{0}, // padding upper (pad_w)
|
| 912 |
+
{1}, // stride (stride_w)
|
| 913 |
+
{1}, // dilation (dilation_w)
|
| 914 |
+
1 // group
|
| 915 |
+
});
|
| 916 |
+
// non-packed input strides.
|
| 917 |
+
problem_shapes.push_back({
|
| 918 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 919 |
+
{1, 8, 64}, // nqk
|
| 920 |
+
{800, 80, 1}, // stride (nqk)
|
| 921 |
+
{64, 1, 64}, // ksc
|
| 922 |
+
{64, 64, 1}, // stride (ksc)
|
| 923 |
+
{0}, // padding lower (pad_w)
|
| 924 |
+
{0}, // padding upper (pad_w)
|
| 925 |
+
{1}, // stride (stride_w)
|
| 926 |
+
{1}, // dilation (dilation_w)
|
| 927 |
+
1 // group
|
| 928 |
+
});
|
| 929 |
+
// non-packed output strides.
|
| 930 |
+
problem_shapes.push_back({
|
| 931 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 932 |
+
{1, 8, 64}, // nqk
|
| 933 |
+
{512, 64, 1}, // stride (nqk)
|
| 934 |
+
{64, 1, 64}, // ksc
|
| 935 |
+
{64, 64, 1}, // stride (ksc)
|
| 936 |
+
{800, 80, 1}, // stride (nwc)
|
| 937 |
+
{0}, // padding lower (pad_w)
|
| 938 |
+
{0}, // padding upper (pad_w)
|
| 939 |
+
{1}, // stride (stride_w)
|
| 940 |
+
{1}, // dilation (dilation_w)
|
| 941 |
+
1 // group
|
| 942 |
+
});
|
| 943 |
+
// Filter-K = 16 for predication
|
| 944 |
+
problem_shapes.push_back({
|
| 945 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 946 |
+
{1, 8, 16},
|
| 947 |
+
{64, 1, 16},
|
| 948 |
+
{0},
|
| 949 |
+
{0},
|
| 950 |
+
{1},
|
| 951 |
+
{1},
|
| 952 |
+
1
|
| 953 |
+
});
|
| 954 |
+
// N = 2 and K = 128 for a larger grid
|
| 955 |
+
problem_shapes.push_back({
|
| 956 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 957 |
+
{2, 8, 96},
|
| 958 |
+
{64, 1, 96},
|
| 959 |
+
{0},
|
| 960 |
+
{0},
|
| 961 |
+
{1},
|
| 962 |
+
{1},
|
| 963 |
+
1
|
| 964 |
+
});
|
| 965 |
+
// N = 7 and K = 256 for a even larger grid
|
| 966 |
+
problem_shapes.push_back({
|
| 967 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 968 |
+
{7, 8, 256},
|
| 969 |
+
{64, 1, 256},
|
| 970 |
+
{0},
|
| 971 |
+
{0},
|
| 972 |
+
{1},
|
| 973 |
+
{1},
|
| 974 |
+
1
|
| 975 |
+
});
|
| 976 |
+
// 3 filter, no padding
|
| 977 |
+
problem_shapes.push_back({
|
| 978 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 979 |
+
{2, 8, 256},
|
| 980 |
+
{64, 3, 256},
|
| 981 |
+
{0},
|
| 982 |
+
{0},
|
| 983 |
+
{1},
|
| 984 |
+
{1},
|
| 985 |
+
1
|
| 986 |
+
});
|
| 987 |
+
// 3 filter, symmetric padding with k % cta_k !=0
|
| 988 |
+
problem_shapes.push_back({
|
| 989 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 990 |
+
{2, 8, 256},
|
| 991 |
+
{32, 3, 256},
|
| 992 |
+
{1},
|
| 993 |
+
{1},
|
| 994 |
+
{1},
|
| 995 |
+
{1},
|
| 996 |
+
1
|
| 997 |
+
});
|
| 998 |
+
// 4 filter, asymmetric padding
|
| 999 |
+
problem_shapes.push_back({
|
| 1000 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1001 |
+
{2, 8, 256},
|
| 1002 |
+
{64, 4, 256},
|
| 1003 |
+
{0},
|
| 1004 |
+
{1},
|
| 1005 |
+
{1},
|
| 1006 |
+
{1},
|
| 1007 |
+
1
|
| 1008 |
+
});
|
| 1009 |
+
// 3 filter, asymmetric padding and dilation of 2
|
| 1010 |
+
problem_shapes.push_back({
|
| 1011 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1012 |
+
{2, 16, 64},
|
| 1013 |
+
{256, 3, 64},
|
| 1014 |
+
{0},
|
| 1015 |
+
{1},
|
| 1016 |
+
{1},
|
| 1017 |
+
{2},
|
| 1018 |
+
1
|
| 1019 |
+
});
|
| 1020 |
+
return problem_shapes;
|
| 1021 |
+
}
|
| 1022 |
+
|
| 1023 |
+
// Specialization for 2D dgrad problems
|
| 1024 |
+
template<>
|
| 1025 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>> inline
|
| 1026 |
+
get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() {
|
| 1027 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>;
|
| 1028 |
+
std::vector<ProblemShape> problem_shapes;
|
| 1029 |
+
problem_shapes.push_back({
|
| 1030 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1031 |
+
{1, 8, 8, 64}, // npqk
|
| 1032 |
+
{64, 1, 1, 64}, // krsc
|
| 1033 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 1034 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 1035 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 1036 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 1037 |
+
1 // group
|
| 1038 |
+
});
|
| 1039 |
+
// non-packed input strides.
|
| 1040 |
+
problem_shapes.push_back({
|
| 1041 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1042 |
+
{1, 8, 8, 64}, // npqk
|
| 1043 |
+
{8000, 800, 80, 1}, // stride (npqk)
|
| 1044 |
+
{64, 1, 1, 64}, // krsc
|
| 1045 |
+
{64, 64, 64, 1}, // stride (krsc)
|
| 1046 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 1047 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 1048 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 1049 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 1050 |
+
1 // group
|
| 1051 |
+
});
|
| 1052 |
+
// non-packed output strides.
|
| 1053 |
+
problem_shapes.push_back({
|
| 1054 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1055 |
+
{1, 8, 8, 64}, // npqk
|
| 1056 |
+
{4096, 512, 64, 1}, // stride (npqk)
|
| 1057 |
+
{64, 1, 1, 64}, // krsc
|
| 1058 |
+
{64, 64, 64, 1}, // stride (krsc)
|
| 1059 |
+
{8000, 800, 80, 1}, // stride (nhwc)
|
| 1060 |
+
{0, 0}, // padding lower (pad_h, pad_w)
|
| 1061 |
+
{0, 0}, // padding upper (pad_h, pad_w)
|
| 1062 |
+
{1, 1}, // stride (stride_h, stride_w)
|
| 1063 |
+
{1, 1}, // dilation (dilation_h, dilation_w)
|
| 1064 |
+
1 // group
|
| 1065 |
+
});
|
| 1066 |
+
// Filter-K = 16 for predication
|
| 1067 |
+
problem_shapes.push_back({
|
| 1068 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1069 |
+
{1, 8, 8, 16},
|
| 1070 |
+
{64, 1, 1, 16},
|
| 1071 |
+
{0, 0},
|
| 1072 |
+
{0, 0},
|
| 1073 |
+
{1, 1},
|
| 1074 |
+
{1, 1},
|
| 1075 |
+
1
|
| 1076 |
+
});
|
| 1077 |
+
// N = 2 and K = 128 for a larger grid
|
| 1078 |
+
problem_shapes.push_back({
|
| 1079 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1080 |
+
{2, 8, 8, 96},
|
| 1081 |
+
{64, 1, 1, 96},
|
| 1082 |
+
{0, 0},
|
| 1083 |
+
{0, 0},
|
| 1084 |
+
{1, 1},
|
| 1085 |
+
{1, 1},
|
| 1086 |
+
1
|
| 1087 |
+
});
|
| 1088 |
+
// N = 7 and K = 256 for a even larger grid
|
| 1089 |
+
problem_shapes.push_back({
|
| 1090 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1091 |
+
{7, 8, 8, 256},
|
| 1092 |
+
{64, 1, 1, 256},
|
| 1093 |
+
{0, 0},
|
| 1094 |
+
{0, 0},
|
| 1095 |
+
{1, 1},
|
| 1096 |
+
{1, 1},
|
| 1097 |
+
1
|
| 1098 |
+
});
|
| 1099 |
+
// 3x3 filter, no padding
|
| 1100 |
+
problem_shapes.push_back({
|
| 1101 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1102 |
+
{2, 8, 8, 256},
|
| 1103 |
+
{64, 3, 3, 256},
|
| 1104 |
+
{0, 0},
|
| 1105 |
+
{0, 0},
|
| 1106 |
+
{1, 1},
|
| 1107 |
+
{1, 1},
|
| 1108 |
+
1
|
| 1109 |
+
});
|
| 1110 |
+
// 3x3 filter, symmetric padding with k % cta_k !=0
|
| 1111 |
+
problem_shapes.push_back({
|
| 1112 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1113 |
+
{2, 8, 8, 256},
|
| 1114 |
+
{32, 3, 3, 256},
|
| 1115 |
+
{1, 1},
|
| 1116 |
+
{1, 1},
|
| 1117 |
+
{1, 1},
|
| 1118 |
+
{1, 1},
|
| 1119 |
+
1
|
| 1120 |
+
});
|
| 1121 |
+
// 2x5 filter, asymmetric padding 1,0/1,0
|
| 1122 |
+
problem_shapes.push_back({
|
| 1123 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1124 |
+
{2, 8, 8, 256},
|
| 1125 |
+
{64, 2, 5, 256},
|
| 1126 |
+
{1, 1},
|
| 1127 |
+
{0, 0},
|
| 1128 |
+
{1, 1},
|
| 1129 |
+
{1, 1},
|
| 1130 |
+
1
|
| 1131 |
+
});
|
| 1132 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
|
| 1133 |
+
problem_shapes.push_back({
|
| 1134 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1135 |
+
{2, 16, 16, 64},
|
| 1136 |
+
{256, 2, 5, 64},
|
| 1137 |
+
{1, 1},
|
| 1138 |
+
{0, 0},
|
| 1139 |
+
{1, 1},
|
| 1140 |
+
{2, 3},
|
| 1141 |
+
1
|
| 1142 |
+
});
|
| 1143 |
+
return problem_shapes;
|
| 1144 |
+
}
|
| 1145 |
+
|
| 1146 |
+
// Specialization for 3D dgrad problems
|
| 1147 |
+
template<>
|
| 1148 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>> inline
|
| 1149 |
+
get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() {
|
| 1150 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>;
|
| 1151 |
+
std::vector<ProblemShape> problem_shapes;
|
| 1152 |
+
// Filter-K = 16 for predication
|
| 1153 |
+
problem_shapes.push_back({
|
| 1154 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1155 |
+
{1, 1, 8, 8, 16},
|
| 1156 |
+
{64, 1, 1, 1, 16},
|
| 1157 |
+
{0, 0, 0},
|
| 1158 |
+
{0, 0, 0},
|
| 1159 |
+
{1, 1, 1},
|
| 1160 |
+
{1, 1, 1},
|
| 1161 |
+
1
|
| 1162 |
+
});
|
| 1163 |
+
// non-packed input output strides.
|
| 1164 |
+
problem_shapes.push_back({
|
| 1165 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1166 |
+
{1, 1, 8, 8, 64}, // nzpqk
|
| 1167 |
+
{8000, 8000, 800, 80, 1}, // stride (nzpqk)
|
| 1168 |
+
{64, 1, 1, 1, 64}, // ktrsc
|
| 1169 |
+
{64, 64, 64, 64, 1}, // stride (ktrsc)
|
| 1170 |
+
{8000, 8000, 800, 80, 1}, // stride (ndhwc)
|
| 1171 |
+
{0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
|
| 1172 |
+
{0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
|
| 1173 |
+
{1, 1, 1}, // stride (stride_d, stride_h, stride_w)
|
| 1174 |
+
{1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
|
| 1175 |
+
1 // group
|
| 1176 |
+
});
|
| 1177 |
+
// N = 7 and K = 256 for a larger grid
|
| 1178 |
+
problem_shapes.push_back({
|
| 1179 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1180 |
+
{2, 1, 8, 8, 96},
|
| 1181 |
+
{64, 1, 1, 1, 96},
|
| 1182 |
+
{0, 0, 0},
|
| 1183 |
+
{0, 0, 0},
|
| 1184 |
+
{1, 1, 1},
|
| 1185 |
+
{1, 1, 1},
|
| 1186 |
+
1
|
| 1187 |
+
});
|
| 1188 |
+
// Filter 3x4x5 + symmetric padding 111
|
| 1189 |
+
problem_shapes.push_back({
|
| 1190 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1191 |
+
{2, 3, 5, 8, 96},
|
| 1192 |
+
{64, 3, 4, 5, 96},
|
| 1193 |
+
{1, 1, 1},
|
| 1194 |
+
{1, 1, 1},
|
| 1195 |
+
{1, 1, 1},
|
| 1196 |
+
{1, 1, 1},
|
| 1197 |
+
1
|
| 1198 |
+
});
|
| 1199 |
+
// Filter 3x4x5 + asymmetric padding 102/010
|
| 1200 |
+
problem_shapes.push_back({
|
| 1201 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1202 |
+
{2, 3, 5, 8, 96},
|
| 1203 |
+
{64, 3, 4, 5, 96},
|
| 1204 |
+
{1, 0, 1},
|
| 1205 |
+
{0, 2, 0},
|
| 1206 |
+
{1, 1, 1},
|
| 1207 |
+
{1, 1, 1},
|
| 1208 |
+
1
|
| 1209 |
+
});
|
| 1210 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
|
| 1211 |
+
problem_shapes.push_back({
|
| 1212 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1213 |
+
{2, 16, 10, 16, 64},
|
| 1214 |
+
{64, 3, 4, 5, 96},
|
| 1215 |
+
{1, 0, 1},
|
| 1216 |
+
{0, 2, 0},
|
| 1217 |
+
{1, 1, 1},
|
| 1218 |
+
{2, 2, 3},
|
| 1219 |
+
1
|
| 1220 |
+
});
|
| 1221 |
+
return problem_shapes;
|
| 1222 |
+
}
|
| 1223 |
+
|
| 1224 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1225 |
+
// Strided Dgrad
|
| 1226 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1227 |
+
|
| 1228 |
+
// Specialization for 1D dgrad problems
|
| 1229 |
+
template<>
|
| 1230 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>> inline
|
| 1231 |
+
get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() {
|
| 1232 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>;
|
| 1233 |
+
std::vector<ProblemShape> problem_shapes;
|
| 1234 |
+
// Test TMA truncation
|
| 1235 |
+
problem_shapes.push_back({
|
| 1236 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1237 |
+
{1, 512, 64}, // nqk
|
| 1238 |
+
{64, 1, 64}, // ksc
|
| 1239 |
+
{0}, // padding lower (pad_w)
|
| 1240 |
+
{0}, // padding upper (pad_w)
|
| 1241 |
+
{2}, // stride (stride_w)
|
| 1242 |
+
{1}, // dilation (dilation_w)
|
| 1243 |
+
1 // group
|
| 1244 |
+
});
|
| 1245 |
+
problem_shapes.push_back({
|
| 1246 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1247 |
+
{1, 1024, 64}, // nqk
|
| 1248 |
+
{64, 1, 64}, // ksc
|
| 1249 |
+
{0}, // padding lower (pad_w)
|
| 1250 |
+
{0}, // padding upper (pad_w)
|
| 1251 |
+
{4}, // stride (stride_w)
|
| 1252 |
+
{1}, // dilation (dilation_w)
|
| 1253 |
+
1 // group
|
| 1254 |
+
});
|
| 1255 |
+
problem_shapes.push_back({
|
| 1256 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1257 |
+
{1, 2048, 64}, // nqk
|
| 1258 |
+
{64, 1, 64}, // ksc
|
| 1259 |
+
{0}, // padding lower (pad_w)
|
| 1260 |
+
{0}, // padding upper (pad_w)
|
| 1261 |
+
{8}, // stride (stride_w)
|
| 1262 |
+
{1}, // dilation (dilation_w)
|
| 1263 |
+
1 // group
|
| 1264 |
+
});
|
| 1265 |
+
// non-packed input/output strides.
|
| 1266 |
+
// stride divides dilation
|
| 1267 |
+
// asymmetric padding
|
| 1268 |
+
problem_shapes.push_back({
|
| 1269 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1270 |
+
{3, 8, 64}, // nqk
|
| 1271 |
+
{800, 80, 1}, // stride (nqk)
|
| 1272 |
+
{64, 3, 64}, // ksc
|
| 1273 |
+
{64, 64, 1}, // stride (ksc)
|
| 1274 |
+
{800, 80, 1}, // stride (nwc)
|
| 1275 |
+
{0}, // padding lower (pad_w)
|
| 1276 |
+
{1}, // padding upper (pad_w)
|
| 1277 |
+
{2}, // stride (stride_w)
|
| 1278 |
+
{4}, // dilation (dilation_w)
|
| 1279 |
+
1 // group
|
| 1280 |
+
});
|
| 1281 |
+
// non-packed input/output strides.
|
| 1282 |
+
// dilation divides stride
|
| 1283 |
+
// asymmetric padding
|
| 1284 |
+
problem_shapes.push_back({
|
| 1285 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1286 |
+
{3, 8, 64}, // nqk
|
| 1287 |
+
{800, 80, 1}, // stride (nqk)
|
| 1288 |
+
{64, 3, 64}, // ksc
|
| 1289 |
+
{64, 64, 1}, // stride (ksc)
|
| 1290 |
+
{800, 80, 1}, // stride (nwc)
|
| 1291 |
+
{1}, // padding lower (pad_w)
|
| 1292 |
+
{0}, // padding upper (pad_w)
|
| 1293 |
+
{4}, // stride (stride_w)
|
| 1294 |
+
{2}, // dilation (dilation_w)
|
| 1295 |
+
1 // group
|
| 1296 |
+
});
|
| 1297 |
+
// non-packed input/output strides.
|
| 1298 |
+
// stride dilation dont divide
|
| 1299 |
+
// asymmetric padding
|
| 1300 |
+
problem_shapes.push_back({
|
| 1301 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1302 |
+
{3, 8, 64}, // nqk
|
| 1303 |
+
{800, 80, 1}, // stride (nqk)
|
| 1304 |
+
{64, 3, 64}, // ksc
|
| 1305 |
+
{64, 64, 1}, // stride (ksc)
|
| 1306 |
+
{800, 80, 1}, // stride (nwc)
|
| 1307 |
+
{1}, // padding lower (pad_w)
|
| 1308 |
+
{2}, // padding upper (pad_w)
|
| 1309 |
+
{2}, // stride (stride_w)
|
| 1310 |
+
{3}, // dilation (dilation_w)
|
| 1311 |
+
1 // group
|
| 1312 |
+
});
|
| 1313 |
+
return problem_shapes;
|
| 1314 |
+
}
|
| 1315 |
+
|
| 1316 |
+
// Specialization for 2D dgrad problems
|
| 1317 |
+
template<>
|
| 1318 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>> inline
|
| 1319 |
+
get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() {
|
| 1320 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>;
|
| 1321 |
+
std::vector<ProblemShape> problem_shapes;
|
| 1322 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
|
| 1323 |
+
// mode 0 stride divides dilation
|
| 1324 |
+
// mode 1 dilation divides stride
|
| 1325 |
+
problem_shapes.push_back({
|
| 1326 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1327 |
+
{3, 16, 16, 64},
|
| 1328 |
+
{256, 2, 5, 64},
|
| 1329 |
+
{1, 0},
|
| 1330 |
+
{0, 1},
|
| 1331 |
+
{2, 4},
|
| 1332 |
+
{4, 2},
|
| 1333 |
+
1
|
| 1334 |
+
});
|
| 1335 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
|
| 1336 |
+
// mode 0 dilation divides stride
|
| 1337 |
+
// mode 1 stride divides dilation
|
| 1338 |
+
problem_shapes.push_back({
|
| 1339 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1340 |
+
{3, 16, 16, 64},
|
| 1341 |
+
{256, 2, 5, 64},
|
| 1342 |
+
{1, 0},
|
| 1343 |
+
{0, 1},
|
| 1344 |
+
{4, 2},
|
| 1345 |
+
{2, 4},
|
| 1346 |
+
1
|
| 1347 |
+
});
|
| 1348 |
+
// 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
|
| 1349 |
+
// stride dilation dont divide
|
| 1350 |
+
problem_shapes.push_back({
|
| 1351 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1352 |
+
{3, 16, 16, 64},
|
| 1353 |
+
{256, 2, 5, 64},
|
| 1354 |
+
{1, 0},
|
| 1355 |
+
{0, 1},
|
| 1356 |
+
{3, 2},
|
| 1357 |
+
{2, 3},
|
| 1358 |
+
1
|
| 1359 |
+
});
|
| 1360 |
+
return problem_shapes;
|
| 1361 |
+
}
|
| 1362 |
+
|
| 1363 |
+
// Specialization for 3D dgrad problems
|
| 1364 |
+
template<>
|
| 1365 |
+
std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>> inline
|
| 1366 |
+
get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() {
|
| 1367 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>;
|
| 1368 |
+
std::vector<ProblemShape> problem_shapes;
|
| 1369 |
+
// Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
|
| 1370 |
+
problem_shapes.push_back({
|
| 1371 |
+
cutlass::conv::Mode::kCrossCorrelation,
|
| 1372 |
+
{2, 16, 10, 16, 64},
|
| 1373 |
+
{64, 3, 4, 5, 96},
|
| 1374 |
+
{1, 0, 1},
|
| 1375 |
+
{0, 2, 0},
|
| 1376 |
+
{2, 1, 2},
|
| 1377 |
+
{4, 2, 3},
|
| 1378 |
+
1
|
| 1379 |
+
});
|
| 1380 |
+
return problem_shapes;
|
| 1381 |
+
}
|
| 1382 |
+
|
| 1383 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 1384 |
+
|
| 1385 |
+
} // namespace cutlass::test
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp
ADDED
|
@@ -0,0 +1,768 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
/*! \file
|
| 32 |
+
\brief Implicit GEMM testbed for 3.x API
|
| 33 |
+
*/
|
| 34 |
+
#pragma once
|
| 35 |
+
|
| 36 |
+
#include "cutlass/cutlass.h"
|
| 37 |
+
#include "../../common/cutlass_unit_test.h"
|
| 38 |
+
|
| 39 |
+
#include "cute/tensor.hpp"
|
| 40 |
+
#include "cutlass/kernel_hardware_info.hpp"
|
| 41 |
+
#include "cutlass/conv/convolution.h"
|
| 42 |
+
#include "cutlass/conv/convnd_problem_shape.hpp"
|
| 43 |
+
#include "../test/unit/gemm/device/gemm_testbed_3x.hpp"
|
| 44 |
+
|
| 45 |
+
#include "thrust/universal_vector.h"
|
| 46 |
+
#include "cutlass/util/distribution.h"
|
| 47 |
+
#include "cutlass/util/host_tensor.h"
|
| 48 |
+
#include "cutlass/util/tensor_view_io.h"
|
| 49 |
+
#include "cutlass/util/packed_stride.hpp"
|
| 50 |
+
#include "cutlass/util/reference/host/conv.hpp"
|
| 51 |
+
#include "cutlass/util/reference/host/tensor_fill.h"
|
| 52 |
+
#include "cutlass/util/reference/host/tensor_copy.h"
|
| 53 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 54 |
+
#include "cutlass/util/reference/host/tensor_norm.h"
|
| 55 |
+
#include "cutlass/util/reference/device/tensor_fill.h"
|
| 56 |
+
#include "cutlass/util/reference/device/tensor_compare.h"
|
| 57 |
+
#include "conv_problem_sizes.hpp"
|
| 58 |
+
#include "../cache_testbed_output.h"
|
| 59 |
+
|
| 60 |
+
#include <iostream>
|
| 61 |
+
|
| 62 |
+
#include "cute/layout.hpp"
|
| 63 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 64 |
+
|
| 65 |
+
namespace test::conv::device {
|
| 66 |
+
|
| 67 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 68 |
+
|
| 69 |
+
// Initializes a flat device buffer
|
| 70 |
+
template <typename Element>
|
| 71 |
+
static void
|
| 72 |
+
initialize_values(
|
| 73 |
+
thrust::universal_vector<Element>& dst_ptr,
|
| 74 |
+
cutlass::Distribution::Kind dist_kind,
|
| 75 |
+
uint64_t seed) {
|
| 76 |
+
if (cutlass::Distribution::Uniform == dist_kind) {
|
| 77 |
+
int scope;
|
| 78 |
+
int bits = cutlass::sizeof_bits<Element>::value;
|
| 79 |
+
|
| 80 |
+
if (bits <= 8) {
|
| 81 |
+
scope = 2;
|
| 82 |
+
}
|
| 83 |
+
else if (bits == 16) {
|
| 84 |
+
scope = 4;
|
| 85 |
+
}
|
| 86 |
+
else {
|
| 87 |
+
scope = 8;
|
| 88 |
+
}
|
| 89 |
+
cutlass::reference::host::BlockFillRandomUniform(
|
| 90 |
+
dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0);
|
| 91 |
+
}
|
| 92 |
+
else if (cutlass::Distribution::Identity == dist_kind) {
|
| 93 |
+
cutlass::reference::host::BlockFillRandomUniform(
|
| 94 |
+
dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0);
|
| 95 |
+
}
|
| 96 |
+
else if (cutlass::Distribution::Gaussian == dist_kind) {
|
| 97 |
+
cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5);
|
| 98 |
+
}
|
| 99 |
+
else if (cutlass::Distribution::Sequential == dist_kind) {
|
| 100 |
+
cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size());
|
| 101 |
+
}
|
| 102 |
+
else {
|
| 103 |
+
std::cerr << "Invalid distribution kind!\n.";
|
| 104 |
+
exit(1);
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 109 |
+
// utils for sparse or dense conv parameters
|
| 110 |
+
|
| 111 |
+
template <class Conv>
|
| 112 |
+
struct DenseConvParams {
|
| 113 |
+
// Default Kernel data types
|
| 114 |
+
using ElementA = typename Conv::ConvKernel::ElementA;
|
| 115 |
+
using ElementB = typename Conv::ConvKernel::ElementB;
|
| 116 |
+
|
| 117 |
+
static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp;
|
| 118 |
+
static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions;
|
| 119 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<ConvOp, NumSpatialDimensions>;
|
| 120 |
+
|
| 121 |
+
// get the default arguments without sparse data
|
| 122 |
+
auto get_mainloop_arguments(
|
| 123 |
+
[[maybe_unused]] ProblemShape const& problem_shape,
|
| 124 |
+
thrust::universal_vector<ElementA>& tensor_A,
|
| 125 |
+
thrust::universal_vector<ElementB>& tensor_B
|
| 126 |
+
) {
|
| 127 |
+
auto args = typename Conv::ConvKernel::MainloopArguments {
|
| 128 |
+
tensor_A.data().get(),
|
| 129 |
+
tensor_B.data().get(),
|
| 130 |
+
};
|
| 131 |
+
return args;
|
| 132 |
+
}
|
| 133 |
+
};
|
| 134 |
+
|
| 135 |
+
template <class Conv>
|
| 136 |
+
struct SparseConvParams {
|
| 137 |
+
};
|
| 138 |
+
|
| 139 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 140 |
+
template <class Conv, bool isSparseEnabled_ = false>
|
| 141 |
+
struct ConvTestbed {
|
| 142 |
+
// Kernel data types
|
| 143 |
+
using ElementA = typename Conv::ConvKernel::ElementA;
|
| 144 |
+
using ElementB = typename Conv::ConvKernel::ElementB;
|
| 145 |
+
using ElementC = cute::conditional_t<cute::is_void_v<typename Conv::ConvKernel::ElementC>,
|
| 146 |
+
typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>;
|
| 147 |
+
using ElementD = typename Conv::ConvKernel::ElementD;
|
| 148 |
+
using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator;
|
| 149 |
+
|
| 150 |
+
// ConvTest for sparse kernel
|
| 151 |
+
static constexpr bool isSparseEnabled = isSparseEnabled_;
|
| 152 |
+
using ConvParams = cute::conditional_t<isSparseEnabled, SparseConvParams<Conv>, DenseConvParams<Conv>>;
|
| 153 |
+
ConvParams params;
|
| 154 |
+
|
| 155 |
+
//
|
| 156 |
+
// FusionOperation derived types/queries
|
| 157 |
+
//
|
| 158 |
+
using FusionOp = typename Conv::EpilogueOutputOp;
|
| 159 |
+
|
| 160 |
+
// fusion types are potentially void if the fusion is not supported
|
| 161 |
+
// helper so we don't try to construct HostTensor with void type
|
| 162 |
+
template <typename T, typename U = uint8_t>
|
| 163 |
+
using non_void_t = cute::conditional_t<cute::is_void_v<T>, U, T>;
|
| 164 |
+
using ElementScalar = typename FusionOp::ElementScalar;
|
| 165 |
+
using ElementCompute = typename FusionOp::ElementCompute;
|
| 166 |
+
using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias<FusionOp>::type;
|
| 167 |
+
using ElementBias = non_void_t<BiasType>;
|
| 168 |
+
using ActivationType = non_void_t<typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation<FusionOp>::type,
|
| 169 |
+
cutlass::epilogue::thread::Identity<ElementCompute>>;
|
| 170 |
+
static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation<FusionOp>::value;
|
| 171 |
+
using ActivationFunctor = cute::conditional_t<IsActivationEnabled, ActivationType, cutlass::epilogue::thread::Identity<ElementCompute>>;
|
| 172 |
+
|
| 173 |
+
static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias<FusionOp>::value &&
|
| 174 |
+
!cute::is_same_v<BiasType, void>;
|
| 175 |
+
static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled<FusionOp>::value;
|
| 176 |
+
|
| 177 |
+
static constexpr bool DisableSource = cute::is_void_v<typename FusionOp::ElementSource>;
|
| 178 |
+
|
| 179 |
+
static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd<FusionOp>::value;
|
| 180 |
+
|
| 181 |
+
using StrideC = typename Conv::ConvKernel::StrideC;
|
| 182 |
+
using StrideD = typename Conv::ConvKernel::StrideD;
|
| 183 |
+
using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp;
|
| 184 |
+
|
| 185 |
+
static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp;
|
| 186 |
+
static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions;
|
| 187 |
+
using ProblemShape = cutlass::conv::ConvProblemShape<ConvOp, NumSpatialDimensions>;
|
| 188 |
+
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
| 189 |
+
using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
| 190 |
+
using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize;
|
| 191 |
+
using Splits = typename gemm::device::detail::Splits;
|
| 192 |
+
|
| 193 |
+
using Schedule = typename Conv::DispatchPolicy::Schedule;
|
| 194 |
+
/// Initialization
|
| 195 |
+
cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform;
|
| 196 |
+
cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform;
|
| 197 |
+
cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform;
|
| 198 |
+
cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform;
|
| 199 |
+
cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros
|
| 200 |
+
uint64_t seed = 6090;
|
| 201 |
+
float epsilon = 0.0f;
|
| 202 |
+
int split_p_slices = 1;
|
| 203 |
+
thrust::universal_vector<ElementA> tensor_A;
|
| 204 |
+
thrust::universal_vector<ElementB> tensor_B;
|
| 205 |
+
thrust::universal_vector<ElementC> tensor_C;
|
| 206 |
+
thrust::universal_vector<ElementD> tensor_D_computed;
|
| 207 |
+
thrust::universal_vector<ElementD> tensor_D_reference;
|
| 208 |
+
thrust::universal_vector<ElementBias> tensor_bias;
|
| 209 |
+
thrust::universal_vector<ElementScalar> tensor_alpha;
|
| 210 |
+
thrust::universal_vector<ElementScalar> tensor_beta;
|
| 211 |
+
|
| 212 |
+
// Return true on success, else false
|
| 213 |
+
bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) {
|
| 214 |
+
tensor_A.resize(sizeof(ElementA) * problem_shape.size_A());
|
| 215 |
+
tensor_B.resize(sizeof(ElementB) * problem_shape.size_B());
|
| 216 |
+
tensor_C.resize(sizeof(ElementC) * problem_shape.size_C());
|
| 217 |
+
tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C());
|
| 218 |
+
tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C());
|
| 219 |
+
tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B())));
|
| 220 |
+
if constexpr (IsPerChannelScaleEnabled) {
|
| 221 |
+
tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B())));
|
| 222 |
+
tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B())));
|
| 223 |
+
}
|
| 224 |
+
initialize_values(tensor_A, init_A, seed);
|
| 225 |
+
initialize_values(tensor_B, init_B, seed * 11);
|
| 226 |
+
initialize_values(tensor_C, init_C, seed * 17);
|
| 227 |
+
initialize_values(tensor_bias, init_bias, seed * 19);
|
| 228 |
+
if constexpr (IsPerChannelScaleEnabled) {
|
| 229 |
+
initialize_values(tensor_alpha, init_bias, seed * 23);
|
| 230 |
+
if constexpr (DisableSource) {
|
| 231 |
+
initialize_values(tensor_beta, init_disable, seed * 27);
|
| 232 |
+
}
|
| 233 |
+
else {
|
| 234 |
+
initialize_values(tensor_beta, init_bias, seed * 27);
|
| 235 |
+
}
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
bool flag = true;
|
| 239 |
+
if constexpr (isSparseEnabled) {
|
| 240 |
+
flag &= params.initialize(problem_shape, tensor_B, static_cast<int>(seed + 2023));
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
return flag;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
// Determine SMEM requirements and waive if not satisfied
|
| 247 |
+
bool sufficient() const {
|
| 248 |
+
int device_idx;
|
| 249 |
+
cudaError_t result = cudaGetDevice(&device_idx);
|
| 250 |
+
if (result != cudaSuccess) {
|
| 251 |
+
throw std::runtime_error("cudaGetDevice() API call failed.");
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
int max_smem_size;
|
| 255 |
+
result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx);
|
| 256 |
+
if (result != cudaSuccess) {
|
| 257 |
+
throw std::runtime_error("cudaDeviceGetAttribute() failed");
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
return max_smem_size >= Conv::ConvKernel::SharedStorageSize;
|
| 261 |
+
}
|
| 262 |
+
|
| 263 |
+
auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) {
|
| 264 |
+
using TensorExtent = cute::array<int32_t, NumSpatialDimensions + 3>;
|
| 265 |
+
using TensorStride = cute::array<int64_t, NumSpatialDimensions + 3>;
|
| 266 |
+
|
| 267 |
+
TensorExtent shape_a_g{};
|
| 268 |
+
TensorExtent shape_b_g{};
|
| 269 |
+
TensorExtent shape_c_g{};
|
| 270 |
+
TensorStride stride_a_g{};
|
| 271 |
+
TensorStride stride_b_g{};
|
| 272 |
+
TensorStride stride_c_g{};
|
| 273 |
+
|
| 274 |
+
auto shape_a = cute::reverse(problem_shape.shape_A);
|
| 275 |
+
auto shape_b = cute::reverse(problem_shape.shape_B);
|
| 276 |
+
auto shape_c = cute::reverse(problem_shape.shape_C);
|
| 277 |
+
auto stride_a = cute::reverse(problem_shape.stride_A);
|
| 278 |
+
auto stride_b = cute::reverse(problem_shape.stride_B);
|
| 279 |
+
auto stride_c = cute::reverse(problem_shape.stride_C);
|
| 280 |
+
|
| 281 |
+
int32_t G = problem_shape.groups;
|
| 282 |
+
|
| 283 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kFprop ||
|
| 284 |
+
ConvOp == cutlass::conv::Operator::kDgrad) {
|
| 285 |
+
// shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g)
|
| 286 |
+
// shape_b_g = (c,s,r,k,t,g)
|
| 287 |
+
// shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g)
|
| 288 |
+
shape_a_g = cute::to_array<int32_t>(tuple_cat(
|
| 289 |
+
cute::make_shape(cute::size<0>(shape_a) / G),
|
| 290 |
+
cute::take<1,NumSpatialDimensions + 2>(shape_a),
|
| 291 |
+
cute::make_shape(G)));
|
| 292 |
+
shape_b_g = cute::to_array<int32_t>(tuple_cat(
|
| 293 |
+
cute::take<0,NumSpatialDimensions + 1>(shape_b),
|
| 294 |
+
cute::make_shape(cute::size<NumSpatialDimensions + 1>(shape_b) / G, G)));
|
| 295 |
+
shape_c_g = cute::to_array<int32_t>(tuple_cat(
|
| 296 |
+
cute::make_shape(cute::size<0>(shape_c) / G),
|
| 297 |
+
cute::take<1,NumSpatialDimensions + 2>(shape_c),
|
| 298 |
+
cute::make_shape(G)));
|
| 299 |
+
|
| 300 |
+
stride_a_g = cute::to_array<int64_t>(append(stride_a, cute::size<0>(shape_a) / G));
|
| 301 |
+
stride_b_g = cute::to_array<int64_t>(append(stride_b,
|
| 302 |
+
cute::size<NumSpatialDimensions + 1>(stride_b) * cute::size<NumSpatialDimensions + 1>(shape_b) / G));
|
| 303 |
+
stride_c_g = cute::to_array<int64_t>(append(stride_c, cute::size<0>(shape_c) / G));
|
| 304 |
+
}
|
| 305 |
+
else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 306 |
+
// shape_a_g = (k,q,p,z,n,g)
|
| 307 |
+
// shape_b_g = (c,w,h,d,n,g)
|
| 308 |
+
// shape_c_g = (c,s,r,k,t,g)
|
| 309 |
+
shape_a_g = cute::to_array<int32_t>(tuple_cat(
|
| 310 |
+
cute::make_shape(cute::size<0>(shape_a) / G),
|
| 311 |
+
cute::take<1,NumSpatialDimensions + 2>(shape_a),
|
| 312 |
+
cute::make_shape(G)));
|
| 313 |
+
shape_b_g = cute::to_array<int32_t>(tuple_cat(
|
| 314 |
+
cute::make_shape(cute::size<0>(shape_b) / G),
|
| 315 |
+
cute::take<1,NumSpatialDimensions + 2>(shape_b),
|
| 316 |
+
cute::make_shape(G)));
|
| 317 |
+
shape_c_g = cute::to_array<int32_t>(tuple_cat(
|
| 318 |
+
cute::take<0,NumSpatialDimensions + 1>(shape_c),
|
| 319 |
+
cute::make_shape(cute::size<NumSpatialDimensions + 1>(shape_c) / G, G)));
|
| 320 |
+
|
| 321 |
+
stride_a_g = cute::to_array<int64_t>(append(stride_a, cute::size<0>(shape_a) / G));
|
| 322 |
+
stride_b_g = cute::to_array<int64_t>(append(stride_b, cute::size<0>(shape_b) / G));
|
| 323 |
+
stride_c_g = cute::to_array<int64_t>(append(stride_c,
|
| 324 |
+
cute::size<NumSpatialDimensions + 1>(stride_c) * cute::size<NumSpatialDimensions + 1>(shape_c) / G));
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
return make_tuple(shape_a_g, shape_b_g, shape_c_g,
|
| 328 |
+
stride_a_g, stride_b_g, stride_c_g);
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
// Executes one test
|
| 332 |
+
bool run(
|
| 333 |
+
ProblemShape const& problem_shape,
|
| 334 |
+
ElementScalar alpha = ElementScalar(1),
|
| 335 |
+
ElementScalar beta = ElementScalar(0),
|
| 336 |
+
dim3 cluster_shape = dim3(0, 0, 0),
|
| 337 |
+
dim3 cluster_shape_fallback = dim3(0, 0, 0),
|
| 338 |
+
RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
|
| 339 |
+
MaxSwizzleSize max_swizzle = MaxSwizzleSize{},
|
| 340 |
+
Splits splits = Splits{},
|
| 341 |
+
DecompositionMode decomposition_mode = DecompositionMode::Heuristic
|
| 342 |
+
) {
|
| 343 |
+
|
| 344 |
+
// Waive test if insufficient CUDA device
|
| 345 |
+
if (!sufficient()) {
|
| 346 |
+
if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
|
| 347 |
+
std::cerr << "Test waived due to insufficient CUDA device.\n";
|
| 348 |
+
}
|
| 349 |
+
return true;
|
| 350 |
+
}
|
| 351 |
+
|
| 352 |
+
bool ret = initialize(problem_shape);
|
| 353 |
+
|
| 354 |
+
if (!ret) {
|
| 355 |
+
std::cerr << "initialize failed for the given problem_shape: \n";
|
| 356 |
+
return false;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
cutlass::KernelHardwareInfo hw_info;
|
| 360 |
+
cudaGetDevice(&hw_info.device_id);
|
| 361 |
+
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
| 362 |
+
|
| 363 |
+
hw_info.cluster_shape = cluster_shape;
|
| 364 |
+
hw_info.cluster_shape_fallback = cluster_shape_fallback;
|
| 365 |
+
|
| 366 |
+
// configure the operator
|
| 367 |
+
Conv conv_op;
|
| 368 |
+
auto stride_C = StrideC{};
|
| 369 |
+
auto stride_D = StrideD{};
|
| 370 |
+
if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
|
| 371 |
+
stride_C = cutlass::make_cute_packed_stride(
|
| 372 |
+
StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
|
| 373 |
+
stride_D = cutlass::make_cute_packed_stride(
|
| 374 |
+
StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
|
| 375 |
+
}
|
| 376 |
+
// Need to support non-packed output strides for fprop and dgrad kernel.
|
| 377 |
+
else {
|
| 378 |
+
cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
|
| 379 |
+
cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
| 380 |
+
});
|
| 381 |
+
cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
|
| 382 |
+
cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
|
| 383 |
+
});
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
| 387 |
+
using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
| 388 |
+
|
| 389 |
+
typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{};
|
| 390 |
+
if constexpr (cute::is_same_v<typename Conv::ConvKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
|
| 391 |
+
scheduler_args = { static_cast<int>(splits), static_cast<int>(max_swizzle), raster_order, decomposition_mode };
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B);
|
| 395 |
+
|
| 396 |
+
auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
|
| 397 |
+
{},
|
| 398 |
+
tensor_C.data().get(),
|
| 399 |
+
stride_C,
|
| 400 |
+
tensor_D_computed.data().get(),
|
| 401 |
+
stride_D,
|
| 402 |
+
};
|
| 403 |
+
|
| 404 |
+
auto args = typename Conv::Arguments {
|
| 405 |
+
problem_shape,
|
| 406 |
+
mainloop_args, // MainloopArguments
|
| 407 |
+
epilogue_args, // EpilogueArguments
|
| 408 |
+
hw_info,
|
| 409 |
+
scheduler_args
|
| 410 |
+
};
|
| 411 |
+
|
| 412 |
+
auto &fusion_args = args.epilogue.thread;
|
| 413 |
+
|
| 414 |
+
fusion_args.alpha = alpha;
|
| 415 |
+
fusion_args.beta = beta;
|
| 416 |
+
|
| 417 |
+
if constexpr (IsPerChannelScaleEnabled) {
|
| 418 |
+
fusion_args.alpha_ptr = tensor_alpha.data().get();
|
| 419 |
+
fusion_args.beta_ptr = tensor_beta.data().get();
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
if constexpr (IsBiasEnabled) {
|
| 423 |
+
fusion_args.bias_ptr = tensor_bias.data().get();
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
// Clamp bound
|
| 427 |
+
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) {
|
| 428 |
+
fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits<ElementCompute>::lowest();
|
| 429 |
+
fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits<ElementCompute>::max();
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
// Scale
|
| 433 |
+
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU_taylor<ElementCompute>> ||
|
| 434 |
+
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU<ElementCompute>> ||
|
| 435 |
+
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledSiLu<ElementCompute>> ||
|
| 436 |
+
cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledHardSwish<ElementCompute>> ) {
|
| 437 |
+
fusion_args.activation.scale = ElementCompute{1};
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
// LeakyRelu
|
| 441 |
+
if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::LeakyReLU<ElementCompute>> ) {
|
| 442 |
+
fusion_args.activation.leaky_alpha = ElementCompute{0};
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
cutlass::Status status = cutlass::Status::kInvalid;
|
| 446 |
+
|
| 447 |
+
status = conv_op.can_implement(args);
|
| 448 |
+
EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess);
|
| 449 |
+
if (status != cutlass::Status::kSuccess) {
|
| 450 |
+
std::cerr << "can_implement failed for the given problem_shape: \n";
|
| 451 |
+
print(problem_shape);
|
| 452 |
+
return false;
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
// find workspace requirement for parallel split-k reduction
|
| 456 |
+
size_t workspace_size = Conv::get_workspace_size(args);
|
| 457 |
+
thrust::universal_vector<uint8_t> workspace(workspace_size);
|
| 458 |
+
|
| 459 |
+
status = conv_op.initialize(args, workspace.data().get());
|
| 460 |
+
if (status != cutlass::Status::kSuccess) {
|
| 461 |
+
cudaError_t error = cudaGetLastError();
|
| 462 |
+
std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
|
| 463 |
+
return true;
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
// run conv3d operator
|
| 467 |
+
status = conv_op();
|
| 468 |
+
|
| 469 |
+
EXPECT_TRUE(status == cutlass::Status::kSuccess);
|
| 470 |
+
if (status != cutlass::Status::kSuccess) {
|
| 471 |
+
return false;
|
| 472 |
+
}
|
| 473 |
+
|
| 474 |
+
bool passed = false;
|
| 475 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 476 |
+
EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: "
|
| 477 |
+
<< cudaGetErrorString(result);
|
| 478 |
+
|
| 479 |
+
// Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us
|
| 480 |
+
auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] =
|
| 481 |
+
transform_shape_and_stride_with_groups(problem_shape);
|
| 482 |
+
auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B())));
|
| 483 |
+
|
| 484 |
+
auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA));
|
| 485 |
+
auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB));
|
| 486 |
+
auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC));
|
| 487 |
+
auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC));
|
| 488 |
+
auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC));
|
| 489 |
+
auto mBias = make_tensor(tensor_bias.data().get(), make_layout(shape_mBias));
|
| 490 |
+
auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias));
|
| 491 |
+
auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias));
|
| 492 |
+
|
| 493 |
+
cutlass::reference::host::ConvEpilogueFusionParams<
|
| 494 |
+
ElementAccumulator,
|
| 495 |
+
ElementScalar,
|
| 496 |
+
ElementCompute,
|
| 497 |
+
ElementC,
|
| 498 |
+
ElementD,
|
| 499 |
+
IsResidualEnabled,
|
| 500 |
+
decltype(mAlpha),
|
| 501 |
+
decltype(mBeta),
|
| 502 |
+
decltype(mBias),
|
| 503 |
+
ActivationFunctor>
|
| 504 |
+
epilogue_fusion_params{};
|
| 505 |
+
|
| 506 |
+
epilogue_fusion_params.alpha = alpha;
|
| 507 |
+
epilogue_fusion_params.beta = beta;
|
| 508 |
+
|
| 509 |
+
if constexpr (IsPerChannelScaleEnabled) {
|
| 510 |
+
epilogue_fusion_params.tensor_alpha = mAlpha;
|
| 511 |
+
epilogue_fusion_params.tensor_beta = mBeta;
|
| 512 |
+
}
|
| 513 |
+
|
| 514 |
+
if constexpr (IsBiasEnabled) {
|
| 515 |
+
epilogue_fusion_params.tensor_bias = mBias;
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
auto padding = cute::reverse(problem_shape.lower_padding);
|
| 519 |
+
auto tstride = cute::reverse(problem_shape.traversal_stride);
|
| 520 |
+
auto dilation = cute::reverse(problem_shape.dilation);
|
| 521 |
+
|
| 522 |
+
cutlass::reference::host::ConvReferenceImpl<
|
| 523 |
+
ConvOp,
|
| 524 |
+
NumSpatialDimensions,
|
| 525 |
+
decltype(mA),
|
| 526 |
+
decltype(mB),
|
| 527 |
+
decltype(mC),
|
| 528 |
+
decltype(mD_ref),
|
| 529 |
+
decltype(padding),
|
| 530 |
+
decltype(tstride),
|
| 531 |
+
decltype(dilation),
|
| 532 |
+
decltype(epilogue_fusion_params)>
|
| 533 |
+
reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params);
|
| 534 |
+
|
| 535 |
+
//
|
| 536 |
+
// Reference check - support caching results
|
| 537 |
+
//
|
| 538 |
+
|
| 539 |
+
CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey<
|
| 540 |
+
ProblemShape,
|
| 541 |
+
ElementA,
|
| 542 |
+
ElementB,
|
| 543 |
+
ElementC,
|
| 544 |
+
ElementD
|
| 545 |
+
>(
|
| 546 |
+
ConvOp,
|
| 547 |
+
problem_shape,
|
| 548 |
+
alpha,
|
| 549 |
+
beta,
|
| 550 |
+
tensor_A,
|
| 551 |
+
tensor_B,
|
| 552 |
+
tensor_C
|
| 553 |
+
);
|
| 554 |
+
|
| 555 |
+
//
|
| 556 |
+
// Look for the cached key
|
| 557 |
+
//
|
| 558 |
+
|
| 559 |
+
bool cached_result_loaded = false;
|
| 560 |
+
CachedTestResult cached_test_result;
|
| 561 |
+
|
| 562 |
+
std::string convnd_result_cache_name =
|
| 563 |
+
std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
|
| 564 |
+
|
| 565 |
+
#if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
|
| 566 |
+
CachedTestResultListing cached_results(convnd_result_cache_name);
|
| 567 |
+
|
| 568 |
+
auto cached = cached_results.find(cached_test_key);
|
| 569 |
+
|
| 570 |
+
cached_result_loaded = cached.first;
|
| 571 |
+
if (cached_result_loaded) {
|
| 572 |
+
cached_test_result = cached.second;
|
| 573 |
+
}
|
| 574 |
+
#endif
|
| 575 |
+
|
| 576 |
+
if (!cached_result_loaded) {
|
| 577 |
+
// Compute reference
|
| 578 |
+
reference_impl.compute_reference();
|
| 579 |
+
|
| 580 |
+
#if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
|
| 581 |
+
cached_test_result.D = TensorHash(tensor_D_reference);
|
| 582 |
+
CachedTestResultListing cached_results(convnd_result_cache_name);
|
| 583 |
+
|
| 584 |
+
cached_results.append(cached_test_key, cached_test_result);
|
| 585 |
+
cached_results.write(convnd_result_cache_name);
|
| 586 |
+
#endif
|
| 587 |
+
} // if (!cached_result_loaded)
|
| 588 |
+
|
| 589 |
+
#if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
|
| 590 |
+
uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed);
|
| 591 |
+
passed = (tensor_D_computed_hash == cached_test_result.D);
|
| 592 |
+
// If hash fails, double check against reference implementation.
|
| 593 |
+
if(!passed) {
|
| 594 |
+
std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key
|
| 595 |
+
<< ", comparing with reference implementation now.\n";
|
| 596 |
+
if (cached_result_loaded) {
|
| 597 |
+
// Compute reference
|
| 598 |
+
reference_impl.compute_reference();
|
| 599 |
+
}
|
| 600 |
+
// Validate kernel against reference
|
| 601 |
+
passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon);
|
| 602 |
+
}
|
| 603 |
+
#else
|
| 604 |
+
// Validate kernel against reference
|
| 605 |
+
passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon);
|
| 606 |
+
#endif
|
| 607 |
+
|
| 608 |
+
EXPECT_TRUE(passed);
|
| 609 |
+
return passed;
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
template<
|
| 613 |
+
class Engine, class Layout,
|
| 614 |
+
class EngineA, class LayoutA,
|
| 615 |
+
class EngineB, class LayoutB,
|
| 616 |
+
class EngineAlpha, class LayoutAlpha,
|
| 617 |
+
class EngineBeta, class LayoutBeta,
|
| 618 |
+
class EngineBias, class LayoutBias>
|
| 619 |
+
static constexpr bool
|
| 620 |
+
compare_reference(
|
| 621 |
+
cute::Tensor<Engine, Layout> const& reference,
|
| 622 |
+
cute::Tensor<Engine, Layout> const& computed,
|
| 623 |
+
cute::Tensor<EngineA, LayoutA> const& A,
|
| 624 |
+
cute::Tensor<EngineB, LayoutB> const& B,
|
| 625 |
+
cute::Tensor<EngineAlpha, LayoutAlpha> const& tensor_alpha,
|
| 626 |
+
cute::Tensor<EngineBeta, LayoutBeta> const& tensor_beta,
|
| 627 |
+
cute::Tensor<EngineBias, LayoutBias> const& tensor_bias,
|
| 628 |
+
float epsilon = 0.0f) {
|
| 629 |
+
if (size(reference) != size(computed)) {
|
| 630 |
+
return false;
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
bool passed = true;
|
| 634 |
+
if (epsilon == 0.0f) {
|
| 635 |
+
// fast refcheck w/o epsilon
|
| 636 |
+
for (size_t i = 0; i < size_t(size(reference)); ++i) {
|
| 637 |
+
if (reference(i) != computed(i)) {
|
| 638 |
+
passed = false;
|
| 639 |
+
printf("[%llu] %f, %f\n", static_cast<unsigned long long>(i),
|
| 640 |
+
float(reference(i)), float(computed(i)));
|
| 641 |
+
break;
|
| 642 |
+
}
|
| 643 |
+
}
|
| 644 |
+
} else {
|
| 645 |
+
// refcheck with epsilon
|
| 646 |
+
for (size_t i = 0; i < size_t(size(reference)); ++i) {
|
| 647 |
+
auto ref = static_cast<float>(reference(i));
|
| 648 |
+
auto act = static_cast<float>(computed(i));
|
| 649 |
+
auto abs_error = std::abs(act - ref);
|
| 650 |
+
auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f);
|
| 651 |
+
if (std::isnan(abs_error) || std::isnan(rel_error) ||
|
| 652 |
+
std::min(abs_error, rel_error) > epsilon) {
|
| 653 |
+
passed = false;
|
| 654 |
+
printf("[%llu] %f, %f\n", static_cast<unsigned long long>(i),
|
| 655 |
+
float(reference(i)), float(computed(i)));
|
| 656 |
+
break;
|
| 657 |
+
}
|
| 658 |
+
}
|
| 659 |
+
}
|
| 660 |
+
#if CUTLASS_DEBUG_TRACE_LEVEL > 1
|
| 661 |
+
if (not passed) {
|
| 662 |
+
cute::print("Reference:");
|
| 663 |
+
cute::print_tensor(reference);
|
| 664 |
+
cute::print("\nComputed:");
|
| 665 |
+
cute::print_tensor(computed);
|
| 666 |
+
cute::print("\n");
|
| 667 |
+
|
| 668 |
+
for (size_t i = 0; i < size_t(size(A)); ++i) {
|
| 669 |
+
printf("[%llu]: A = %f\n", static_cast<unsigned long long>(i), float(A(i)));
|
| 670 |
+
}
|
| 671 |
+
for (size_t i = 0; i < size_t(size(B)); ++i) {
|
| 672 |
+
printf("[%llu]: B = %f\n", static_cast<unsigned long long>(i), float(B(i)));
|
| 673 |
+
}
|
| 674 |
+
if constexpr (IsPerChannelScaleEnabled) {
|
| 675 |
+
for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) {
|
| 676 |
+
printf("[%llu]: alpha = %f\n", static_cast<unsigned long long>(i),
|
| 677 |
+
float(tensor_alpha(i)));
|
| 678 |
+
}
|
| 679 |
+
for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) {
|
| 680 |
+
printf("[%llu]: beta = %f\n", static_cast<unsigned long long>(i),
|
| 681 |
+
float(tensor_beta(i)));
|
| 682 |
+
}
|
| 683 |
+
}
|
| 684 |
+
if constexpr (IsBiasEnabled) {
|
| 685 |
+
for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) {
|
| 686 |
+
printf("[%llu]: bias = %f\n", static_cast<unsigned long long>(i),
|
| 687 |
+
float(tensor_bias(i)));
|
| 688 |
+
}
|
| 689 |
+
}
|
| 690 |
+
for (size_t i = 0; i < size_t(size(reference)); ++i) {
|
| 691 |
+
printf("[%llu]: ref = %f, computed = %f\n", static_cast<unsigned long long>(i),
|
| 692 |
+
float(reference(i)), float(computed(i)));
|
| 693 |
+
}
|
| 694 |
+
}
|
| 695 |
+
#endif
|
| 696 |
+
return passed;
|
| 697 |
+
}
|
| 698 |
+
};
|
| 699 |
+
|
| 700 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 701 |
+
|
| 702 |
+
template <typename Conv, bool SupportStrides = (Conv::DispatchPolicy::ConvOp != cutlass::conv::Operator::kDgrad)>
|
| 703 |
+
bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f,
|
| 704 |
+
dim3 cluster_shape = dim3(0, 0, 0),
|
| 705 |
+
dim3 cluster_shape_fallback = dim3(0, 0, 0)
|
| 706 |
+
) {
|
| 707 |
+
using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar;
|
| 708 |
+
|
| 709 |
+
bool passed = true;
|
| 710 |
+
ConvTestbed<Conv> testbed;
|
| 711 |
+
testbed.epsilon = epsilon;
|
| 712 |
+
auto problem_vector = get_conv_problem_vector<
|
| 713 |
+
Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>();
|
| 714 |
+
|
| 715 |
+
using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
|
| 716 |
+
using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
|
| 717 |
+
using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize;
|
| 718 |
+
using Splits = typename gemm::device::detail::Splits;
|
| 719 |
+
|
| 720 |
+
std::vector<DecompositionMode> decomposition_modes = {DecompositionMode::Heuristic};
|
| 721 |
+
static constexpr bool UsesStreamKScheduler = cute::is_same_v<typename Conv::ConvKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>;
|
| 722 |
+
if constexpr (UsesStreamKScheduler) {
|
| 723 |
+
decomposition_modes.push_back(DecompositionMode::DataParallel);
|
| 724 |
+
decomposition_modes.push_back(DecompositionMode::SplitK);
|
| 725 |
+
decomposition_modes.push_back(DecompositionMode::StreamK);
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
for (auto conv_problem : problem_vector) {
|
| 729 |
+
#if CUTLASS_DEBUG_TRACE_LEVEL > 0
|
| 730 |
+
print(conv_problem);
|
| 731 |
+
#endif
|
| 732 |
+
for (DecompositionMode decomp_mode : decomposition_modes) {
|
| 733 |
+
std::vector problem_splits = {Splits{1}};
|
| 734 |
+
if constexpr (UsesStreamKScheduler) {
|
| 735 |
+
if (decomp_mode == DecompositionMode::SplitK) {
|
| 736 |
+
problem_splits.push_back(Splits{2});
|
| 737 |
+
problem_splits.push_back(Splits{4});
|
| 738 |
+
}
|
| 739 |
+
}
|
| 740 |
+
for (auto splits : problem_splits) {
|
| 741 |
+
|
| 742 |
+
passed = testbed.run(
|
| 743 |
+
conv_problem,
|
| 744 |
+
cutlass::from_real<ElementScalar>(alpha),
|
| 745 |
+
cutlass::from_real<ElementScalar>(beta),
|
| 746 |
+
cluster_shape,
|
| 747 |
+
cluster_shape_fallback,
|
| 748 |
+
RasterOrderOptions::Heuristic, // raster_order
|
| 749 |
+
MaxSwizzleSize(1),
|
| 750 |
+
splits,
|
| 751 |
+
decomp_mode
|
| 752 |
+
);
|
| 753 |
+
if (!passed) {
|
| 754 |
+
printf("Failed test for "); print(conv_problem);
|
| 755 |
+
return false;
|
| 756 |
+
}
|
| 757 |
+
} // splits
|
| 758 |
+
} // decomposition_mode
|
| 759 |
+
}
|
| 760 |
+
|
| 761 |
+
return passed;
|
| 762 |
+
}
|
| 763 |
+
|
| 764 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
| 765 |
+
|
| 766 |
+
} // namespace test::conv::device
|
| 767 |
+
|
| 768 |
+
/////////////////////////////////////////////////////////////////////////////////////////////////
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#include "cutlass_unit_test.h"
|
| 33 |
+
|
| 34 |
+
#include <iostream>
|
| 35 |
+
#include <iomanip>
|
| 36 |
+
#include <utility>
|
| 37 |
+
#include <type_traits>
|
| 38 |
+
#include <vector>
|
| 39 |
+
#include <numeric>
|
| 40 |
+
|
| 41 |
+
#include <thrust/host_vector.h>
|
| 42 |
+
#include <thrust/device_vector.h>
|
| 43 |
+
|
| 44 |
+
#include <cute/tensor.hpp>
|
| 45 |
+
|
| 46 |
+
using namespace cute;
|
| 47 |
+
|
| 48 |
+
template <class ElementType, class SmemLayout>
|
| 49 |
+
struct SharedStorage
|
| 50 |
+
{
|
| 51 |
+
cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
template <class T, class TiledCopy, class GmemLayout, class SmemLayout>
|
| 55 |
+
__global__ void
|
| 56 |
+
test_tiled_cp_async_device_cute(T const* g_in, T* g_out,
|
| 57 |
+
TiledCopy const tiled_copy,
|
| 58 |
+
GmemLayout gmem_layout, SmemLayout smem_layout)
|
| 59 |
+
{
|
| 60 |
+
using namespace cute;
|
| 61 |
+
|
| 62 |
+
extern __shared__ char shared_memory[];
|
| 63 |
+
using SharedStorage = SharedStorage<T, SmemLayout>;
|
| 64 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
| 65 |
+
|
| 66 |
+
auto thr_copy = tiled_copy.get_slice(threadIdx.x);
|
| 67 |
+
Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout);
|
| 68 |
+
Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout);
|
| 69 |
+
|
| 70 |
+
// Construct SMEM tensor
|
| 71 |
+
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout);
|
| 72 |
+
|
| 73 |
+
auto tAgA = thr_copy.partition_S(gA);
|
| 74 |
+
auto tAsA = thr_copy.partition_D(sA);
|
| 75 |
+
|
| 76 |
+
#if 0
|
| 77 |
+
if (thread0()) {
|
| 78 |
+
print("gA : "); print(gA.layout()); print("\n");
|
| 79 |
+
print("sA : "); print(sA.layout()); print("\n");
|
| 80 |
+
print("tAgA: "); print(tAgA.layout()); print("\n");
|
| 81 |
+
print("tAsA: "); print(tAsA.layout()); print("\n");
|
| 82 |
+
}
|
| 83 |
+
#endif
|
| 84 |
+
|
| 85 |
+
copy(tiled_copy, tAgA, tAsA);
|
| 86 |
+
|
| 87 |
+
cp_async_fence();
|
| 88 |
+
cp_async_wait<0>();
|
| 89 |
+
__syncthreads();
|
| 90 |
+
|
| 91 |
+
// Store trivially smem -> gmem
|
| 92 |
+
|
| 93 |
+
if (thread0()) {
|
| 94 |
+
copy(sA, gB);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
template <class T, class TiledCopy, class GMEM_Layout, class SMEM_Layout>
|
| 100 |
+
void
|
| 101 |
+
test_tiled_cp_async(
|
| 102 |
+
TiledCopy const tiled_copy,
|
| 103 |
+
GMEM_Layout const& gmem_layout,
|
| 104 |
+
SMEM_Layout const& smem_layout)
|
| 105 |
+
{
|
| 106 |
+
using namespace cute;
|
| 107 |
+
|
| 108 |
+
// Allocate and initialize host test data
|
| 109 |
+
size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
|
| 110 |
+
thrust::host_vector<T> h_in(N);
|
| 111 |
+
Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
|
| 112 |
+
for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }
|
| 113 |
+
|
| 114 |
+
// Allocate and initialize device test data
|
| 115 |
+
thrust::device_vector<T> d_in = h_in;
|
| 116 |
+
thrust::device_vector<T> d_out(h_in.size(), T(-1));
|
| 117 |
+
|
| 118 |
+
// Launch
|
| 119 |
+
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
|
| 120 |
+
test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>(
|
| 121 |
+
reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
|
| 122 |
+
reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
|
| 123 |
+
tiled_copy,
|
| 124 |
+
gmem_layout,
|
| 125 |
+
smem_layout);
|
| 126 |
+
|
| 127 |
+
// Copy results back to host
|
| 128 |
+
thrust::host_vector<T> h_out = d_out;
|
| 129 |
+
Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
|
| 130 |
+
|
| 131 |
+
// Validate the results. Print only the first 3 errors.
|
| 132 |
+
int count = 3;
|
| 133 |
+
for (int i = 0; i < size(hA_out) && count > 0; ++i) {
|
| 134 |
+
EXPECT_EQ(hA_in(i), hA_out(i));
|
| 135 |
+
if (hA_in(i) != hA_out(i)) {
|
| 136 |
+
--count;
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
template <typename T, typename M, typename N, typename GMEM_STRIDE_TYPE, typename SMEM_LAYOUT, typename TILED_COPY>
|
| 142 |
+
void test_cp_async_no_swizzle() {
|
| 143 |
+
using namespace cute;
|
| 144 |
+
auto smem_atom = SMEM_LAYOUT{};
|
| 145 |
+
auto smem_layout = tile_to_shape(smem_atom, Shape<M, N>{});
|
| 146 |
+
auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{});
|
| 147 |
+
test_tiled_cp_async<T>(TILED_COPY{}, gmem_layout, smem_layout);
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
template <typename T, typename M, typename N, typename GMEM_STRIDE_TYPE, typename SWIZZLE_ATOM, typename SMEM_LAYOUT, typename TILED_COPY>
|
| 151 |
+
void test_cp_async_with_swizzle() {
|
| 152 |
+
using namespace cute;
|
| 153 |
+
auto swizzle_atom = SWIZZLE_ATOM{};
|
| 154 |
+
auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{});
|
| 155 |
+
auto smem_layout = tile_to_shape(smem_atom, Shape<M, N>{});
|
| 156 |
+
auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{});
|
| 157 |
+
test_tiled_cp_async<T>(TILED_COPY{}, gmem_layout, smem_layout);
|
| 158 |
+
}
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp
ADDED
|
@@ -0,0 +1,775 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include "cutlass/relatively_equal.h"
|
| 35 |
+
#include "cutlass_unit_test.h"
|
| 36 |
+
#include "cutlass/util/reference/host/tensor_compare.h"
|
| 37 |
+
|
| 38 |
+
#include <iostream>
|
| 39 |
+
|
| 40 |
+
#include <thrust/host_vector.h>
|
| 41 |
+
#include <thrust/device_vector.h>
|
| 42 |
+
|
| 43 |
+
#include <cute/tensor.hpp>
|
| 44 |
+
|
| 45 |
+
using namespace cute;
|
| 46 |
+
|
| 47 |
+
template<typename T>
|
| 48 |
+
struct fp64_tester {
|
| 49 |
+
using value_type = double;
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
template<typename T>
|
| 53 |
+
struct fp64_tester<complex<T>> {
|
| 54 |
+
using value_type = complex<double>;
|
| 55 |
+
};
|
| 56 |
+
|
| 57 |
+
template<class TA,
|
| 58 |
+
class TB,
|
| 59 |
+
class TC,
|
| 60 |
+
class ALayout, // logical shape (M, K)
|
| 61 |
+
class BLayout, // logical shape (N, K)
|
| 62 |
+
class CLayout> // logical shape (M, N)
|
| 63 |
+
auto host_generate_gemm_inputs(
|
| 64 |
+
ALayout a_layout,
|
| 65 |
+
BLayout b_layout,
|
| 66 |
+
CLayout c_layout
|
| 67 |
+
) {
|
| 68 |
+
thrust::host_vector<TA> h_a(cosize(a_layout));
|
| 69 |
+
thrust::host_vector<TB> h_b(cosize(b_layout));
|
| 70 |
+
thrust::host_vector<TC> h_c(cosize(c_layout));
|
| 71 |
+
thrust::host_vector<TC> h_c_out(cosize(c_layout));
|
| 72 |
+
|
| 73 |
+
auto h_a_tensor = make_tensor(h_a.data(), a_layout);
|
| 74 |
+
auto h_b_tensor = make_tensor(h_b.data(), b_layout);
|
| 75 |
+
auto h_c_tensor = make_tensor(h_c.data(), c_layout);
|
| 76 |
+
size_t max_size = std::max<size_t>({static_cast<size_t>(size(a_layout)),
|
| 77 |
+
static_cast<size_t>(size(b_layout)),
|
| 78 |
+
static_cast<size_t>(size(c_layout))});
|
| 79 |
+
for (size_t i = 0; i < max_size; ++i) {
|
| 80 |
+
double di = static_cast<double>(i);
|
| 81 |
+
if(i < size(a_layout)) {
|
| 82 |
+
h_a_tensor(i) = static_cast<TA>(di / size(a_layout));
|
| 83 |
+
}
|
| 84 |
+
if(i < size(b_layout)) {
|
| 85 |
+
h_b_tensor(i) = static_cast<TB>(di / size(a_layout));
|
| 86 |
+
}
|
| 87 |
+
if(i < size(c_layout)) {
|
| 88 |
+
h_c_tensor(i) = static_cast<TC>((di*di) / size(a_layout));
|
| 89 |
+
}
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
return std::make_tuple(h_a, h_b, h_c, h_c_out);
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
template<class Alpha, class EngineA, class ALayout,
|
| 96 |
+
class EngineB, class BLayout,
|
| 97 |
+
class Beta, class EngineC, class CLayout,
|
| 98 |
+
class ALoadTransform = cute::identity,
|
| 99 |
+
class BLoadTransform = cute::identity,
|
| 100 |
+
class CLoadTransform = cute::identity,
|
| 101 |
+
class CStoreTransform = cute::identity>
|
| 102 |
+
thrust::host_vector<typename EngineC::value_type>
|
| 103 |
+
host_reference_gemm(Alpha alpha,
|
| 104 |
+
Tensor<EngineA, ALayout> const& h_a_tensor,
|
| 105 |
+
Tensor<EngineB, BLayout> const& h_b_tensor,
|
| 106 |
+
Beta beta,
|
| 107 |
+
Tensor<EngineC, CLayout> const& h_c_tensor,
|
| 108 |
+
ALoadTransform const& a_load_transform = {},
|
| 109 |
+
BLoadTransform const& b_load_transform = {},
|
| 110 |
+
CLoadTransform const& c_load_transform = {},
|
| 111 |
+
CStoreTransform const& c_store_transform = {})
|
| 112 |
+
{
|
| 113 |
+
// Cannot use ::value_type because it propagates to complex::value_type,
|
| 114 |
+
// so ViewEngine<complex<double>>::value_type == double
|
| 115 |
+
using TA = remove_cv_t<typename EngineA::element_type>;
|
| 116 |
+
using TB = remove_cv_t<typename EngineB::element_type>;
|
| 117 |
+
using TC = remove_cv_t<typename EngineC::element_type>;
|
| 118 |
+
|
| 119 |
+
using tester = fp64_tester<TC>;
|
| 120 |
+
using ABC_64 = typename tester::value_type;
|
| 121 |
+
|
| 122 |
+
static_assert(std::is_same_v<typename fp64_tester<TA>::value_type, typename fp64_tester<TB>::value_type>);
|
| 123 |
+
static_assert(std::is_same_v<typename fp64_tester<TB>::value_type, typename fp64_tester<TC>::value_type>);
|
| 124 |
+
|
| 125 |
+
thrust::host_vector<TC> h_c_ref(cosize(h_c_tensor.layout()), static_cast<TC>(0.0));
|
| 126 |
+
auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout());
|
| 127 |
+
// A * B
|
| 128 |
+
for (int k = 0; k < size<1>(h_a_tensor); k++) {
|
| 129 |
+
for (int m = 0; m < size<0>(h_a_tensor); m++) {
|
| 130 |
+
for (int n = 0; n < size<0>(h_b_tensor); n++) {
|
| 131 |
+
const auto a_value = a_load_transform(h_a_tensor(m, k));
|
| 132 |
+
const auto b_value = b_load_transform(h_b_tensor(n, k));
|
| 133 |
+
const auto a_value_fp64 = static_cast<ABC_64>(a_value);
|
| 134 |
+
const auto b_value_fp64 = static_cast<ABC_64>(b_value);
|
| 135 |
+
h_c_ref_tensor(m, n) += static_cast<TC>(a_value_fp64 * b_value_fp64);
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
}
|
| 139 |
+
// C = A*B + C
|
| 140 |
+
for (int i = 0; i < size(h_c_ref_tensor); i++) {
|
| 141 |
+
const auto ab_value_fp64 = static_cast<ABC_64>(h_c_ref_tensor(i));
|
| 142 |
+
const auto c_value_fp64 = static_cast<ABC_64>(c_load_transform(h_c_tensor(i)));
|
| 143 |
+
h_c_ref_tensor(i) = c_store_transform(static_cast<TC>(alpha * ab_value_fp64 + beta * c_value_fp64));
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
return h_c_ref;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
template<class EngineC, class CLayout>
|
| 150 |
+
void verify_gemm_correctness(cute::Tensor<EngineC, CLayout> const& h_c_out_tensor,
|
| 151 |
+
cute::Tensor<EngineC, CLayout> const& h_c_ref_tensor)
|
| 152 |
+
{
|
| 153 |
+
// Cannot use ::value_type because it propagates to complex::value_type,
|
| 154 |
+
// so ViewEngine<complex<double>>::value_type == double
|
| 155 |
+
using TC = remove_cv_t<typename EngineC::element_type>;
|
| 156 |
+
|
| 157 |
+
using tester = fp64_tester<TC>;
|
| 158 |
+
using ABC_64 = typename tester::value_type;
|
| 159 |
+
|
| 160 |
+
for (int i = 0; i < size(h_c_ref_tensor); i++) {
|
| 161 |
+
ABC_64 h_c_ref_i = h_c_ref_tensor(i);
|
| 162 |
+
ABC_64 h_c_out_i = h_c_out_tensor(i);
|
| 163 |
+
double epsilon(0.1f);
|
| 164 |
+
double nonzero_floor(std::numeric_limits<double>::min());
|
| 165 |
+
bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor);
|
| 166 |
+
ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i;
|
| 167 |
+
}
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
template<uint32_t ThreadBlockSize,
|
| 172 |
+
uint32_t CopyMaxVecBits,
|
| 173 |
+
class GMemALayout,
|
| 174 |
+
class GMemBLayout,
|
| 175 |
+
class GMemCLayout,
|
| 176 |
+
class SMemALayout,
|
| 177 |
+
class SMemBLayout,
|
| 178 |
+
class SMemCLayout,
|
| 179 |
+
class TA,
|
| 180 |
+
class TB,
|
| 181 |
+
class TC,
|
| 182 |
+
class Alpha,
|
| 183 |
+
class Beta,
|
| 184 |
+
class TiledMma,
|
| 185 |
+
class ALoadTransform,
|
| 186 |
+
class BLoadTransform,
|
| 187 |
+
class CLoadTransform,
|
| 188 |
+
class CStoreTransform,
|
| 189 |
+
class SMemCopyOpA,
|
| 190 |
+
class SMemCopyOpB,
|
| 191 |
+
class SMemCopyLdOpC,
|
| 192 |
+
class SMemCopyStOpC>
|
| 193 |
+
__launch_bounds__(ThreadBlockSize) __global__ void
|
| 194 |
+
cooperative_gemm_kernel(GMemALayout gmem_a_layout,
|
| 195 |
+
GMemBLayout gmem_b_layout,
|
| 196 |
+
GMemCLayout gmem_c_layout,
|
| 197 |
+
SMemALayout smem_a_layout,
|
| 198 |
+
SMemBLayout smem_b_layout,
|
| 199 |
+
SMemCLayout smem_c_layout,
|
| 200 |
+
TA const* a,
|
| 201 |
+
TB const* b,
|
| 202 |
+
TC const* c,
|
| 203 |
+
TC * c_out,
|
| 204 |
+
Alpha const alpha,
|
| 205 |
+
Beta const beta,
|
| 206 |
+
TiledMma tiled_mma,
|
| 207 |
+
ALoadTransform a_load_transform,
|
| 208 |
+
BLoadTransform b_load_transform,
|
| 209 |
+
CLoadTransform c_load_transform,
|
| 210 |
+
CStoreTransform c_store_transform,
|
| 211 |
+
SMemCopyOpA a_copy_op,
|
| 212 |
+
SMemCopyOpB b_copy_op,
|
| 213 |
+
SMemCopyLdOpC c_copy_ld_op,
|
| 214 |
+
SMemCopyStOpC c_copy_st_op)
|
| 215 |
+
{
|
| 216 |
+
using namespace cute;
|
| 217 |
+
|
| 218 |
+
Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout);
|
| 219 |
+
Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout);
|
| 220 |
+
Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout);
|
| 221 |
+
Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout);
|
| 222 |
+
|
| 223 |
+
constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
|
| 224 |
+
|
| 225 |
+
extern __shared__ float4 smem_buf[];
|
| 226 |
+
auto* smem_ptr = reinterpret_cast<unsigned char*>(smem_buf);
|
| 227 |
+
auto* smem_ptr_a = smem_ptr;
|
| 228 |
+
auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes);
|
| 229 |
+
auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes);
|
| 230 |
+
|
| 231 |
+
Tensor s_a_tensor = make_tensor(make_smem_ptr<TA>(smem_ptr_a), smem_a_layout);
|
| 232 |
+
Tensor s_b_tensor = make_tensor(make_smem_ptr<TB>(smem_ptr_b), smem_b_layout);
|
| 233 |
+
Tensor s_c_tensor = make_tensor(make_smem_ptr<TC>(smem_ptr_c), smem_c_layout);
|
| 234 |
+
|
| 235 |
+
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_a_tensor, s_a_tensor);
|
| 236 |
+
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_b_tensor, s_b_tensor);
|
| 237 |
+
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_c_tensor, s_c_tensor);
|
| 238 |
+
|
| 239 |
+
cp_async_fence();
|
| 240 |
+
cp_async_wait<0>();
|
| 241 |
+
__syncthreads();
|
| 242 |
+
|
| 243 |
+
cooperative_gemm(
|
| 244 |
+
threadIdx.x, tiled_mma,
|
| 245 |
+
alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor,
|
| 246 |
+
a_load_transform, b_load_transform, c_load_transform, c_store_transform,
|
| 247 |
+
a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op
|
| 248 |
+
);
|
| 249 |
+
__syncthreads();
|
| 250 |
+
|
| 251 |
+
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, s_c_tensor, g_c_out_tensor);
|
| 252 |
+
}
|
| 253 |
+
|
| 254 |
+
template<uint32_t ThreadBlockSize,
|
| 255 |
+
uint32_t CopyMaxVecBits,
|
| 256 |
+
class GMemALayout,
|
| 257 |
+
class GMemBLayout,
|
| 258 |
+
class GMemCLayout,
|
| 259 |
+
class SMemALayout,
|
| 260 |
+
class SMemBLayout,
|
| 261 |
+
class TA,
|
| 262 |
+
class TB,
|
| 263 |
+
class TC,
|
| 264 |
+
class TiledMma,
|
| 265 |
+
class ALoadTransform,
|
| 266 |
+
class BLoadTransform,
|
| 267 |
+
class CLoadTransform,
|
| 268 |
+
class CStoreTransform,
|
| 269 |
+
class SMemCopyOpA,
|
| 270 |
+
class SMemCopyOpB>
|
| 271 |
+
__launch_bounds__(ThreadBlockSize) __global__ void
|
| 272 |
+
cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout,
|
| 273 |
+
GMemBLayout gmem_b_layout,
|
| 274 |
+
GMemCLayout gmem_c_layout,
|
| 275 |
+
SMemALayout smem_a_layout,
|
| 276 |
+
SMemBLayout smem_b_layout,
|
| 277 |
+
TA const* a,
|
| 278 |
+
TB const* b,
|
| 279 |
+
TC const* c,
|
| 280 |
+
TC * c_out,
|
| 281 |
+
TiledMma tiled_mma,
|
| 282 |
+
ALoadTransform a_load_transform,
|
| 283 |
+
BLoadTransform b_load_transform,
|
| 284 |
+
CLoadTransform c_load_transform,
|
| 285 |
+
CStoreTransform c_store_transform,
|
| 286 |
+
SMemCopyOpA a_copy_op,
|
| 287 |
+
SMemCopyOpB b_copy_op)
|
| 288 |
+
{
|
| 289 |
+
using namespace cute;
|
| 290 |
+
|
| 291 |
+
Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout);
|
| 292 |
+
Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout);
|
| 293 |
+
Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout);
|
| 294 |
+
Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout);
|
| 295 |
+
|
| 296 |
+
constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
|
| 297 |
+
|
| 298 |
+
extern __shared__ float4 smem_buf[];
|
| 299 |
+
auto* smem_ptr = reinterpret_cast<unsigned char*>(smem_buf);
|
| 300 |
+
auto* smem_ptr_a = smem_ptr;
|
| 301 |
+
auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes);
|
| 302 |
+
|
| 303 |
+
Tensor s_a_tensor = make_tensor(make_smem_ptr<TA>(smem_ptr_a), smem_a_layout);
|
| 304 |
+
Tensor s_b_tensor = make_tensor(make_smem_ptr<TB>(smem_ptr_b), smem_b_layout);
|
| 305 |
+
|
| 306 |
+
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_a_tensor, s_a_tensor);
|
| 307 |
+
cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_b_tensor, s_b_tensor);
|
| 308 |
+
|
| 309 |
+
cp_async_fence();
|
| 310 |
+
cp_async_wait<0>();
|
| 311 |
+
__syncthreads();
|
| 312 |
+
|
| 313 |
+
// Create C fragment for storing intermediate results
|
| 314 |
+
auto thr_mma = TiledMma().get_thread_slice(threadIdx.x);
|
| 315 |
+
Tensor g_c_partition = thr_mma.partition_C(g_c_tensor);
|
| 316 |
+
Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor);
|
| 317 |
+
Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition);
|
| 318 |
+
|
| 319 |
+
// Create indexing help for predicated GEMMs
|
| 320 |
+
Tensor cC = make_identity_tensor(shape(gmem_c_layout));
|
| 321 |
+
Tensor tCcC = thr_mma.partition_C(cC);
|
| 322 |
+
|
| 323 |
+
// Load C from global
|
| 324 |
+
// (always loading in predicated way)
|
| 325 |
+
CUTE_UNROLL
|
| 326 |
+
for (int i = 0; i < size(r_c_partition); ++i)
|
| 327 |
+
{
|
| 328 |
+
if (elem_less(tCcC(i), shape(g_c_tensor)))
|
| 329 |
+
{
|
| 330 |
+
r_c_partition(i) = c_load_transform(g_c_partition(i));
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
cooperative_gemm(
|
| 335 |
+
threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition,
|
| 336 |
+
a_load_transform, b_load_transform, a_copy_op, b_copy_op
|
| 337 |
+
);
|
| 338 |
+
|
| 339 |
+
__syncthreads();
|
| 340 |
+
|
| 341 |
+
// Store C to global
|
| 342 |
+
// (always storing in predicated way)
|
| 343 |
+
CUTE_UNROLL
|
| 344 |
+
for (int i = 0; i < size(r_c_partition); ++i)
|
| 345 |
+
{
|
| 346 |
+
if (elem_less(tCcC(i), shape(g_c_tensor)))
|
| 347 |
+
{
|
| 348 |
+
g_c_out_partition(i) = c_store_transform(r_c_partition(i));
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
template<uint32_t ThreadBlockSize,
|
| 354 |
+
uint32_t CopyMaxVecBits,
|
| 355 |
+
class TA,
|
| 356 |
+
class TB,
|
| 357 |
+
class TC,
|
| 358 |
+
class GMemALayout, // logical shape (M, K)
|
| 359 |
+
class GMemBLayout, // logical shape (N, K)
|
| 360 |
+
class GMemCLayout, // logical shape (M, N)
|
| 361 |
+
class SMemALayout, // logical shape (M, K)
|
| 362 |
+
class SMemBLayout, // logical shape (N, K)
|
| 363 |
+
class SMemCLayout, // logical shape (M, N)
|
| 364 |
+
class TiledMma,
|
| 365 |
+
class ALoadTransform = cute::identity,
|
| 366 |
+
class BLoadTransform = cute::identity,
|
| 367 |
+
class CLoadTransform = cute::identity,
|
| 368 |
+
class CStoreTransform = cute::identity,
|
| 369 |
+
class ASMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
|
| 370 |
+
class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
|
| 371 |
+
class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
|
| 372 |
+
class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>>
|
| 373 |
+
void test_cooperative_gemm(GMemALayout gmem_a_layout,
|
| 374 |
+
GMemBLayout gmem_b_layout,
|
| 375 |
+
GMemCLayout gmem_c_layout,
|
| 376 |
+
SMemALayout smem_a_layout,
|
| 377 |
+
SMemBLayout smem_b_layout,
|
| 378 |
+
SMemCLayout smem_c_layout,
|
| 379 |
+
TiledMma tiled_mma,
|
| 380 |
+
ALoadTransform a_load_transform = {},
|
| 381 |
+
BLoadTransform b_load_transform = {},
|
| 382 |
+
CLoadTransform c_load_transform = {},
|
| 383 |
+
CStoreTransform c_store_transform = {},
|
| 384 |
+
ASMemCopyOp a_smem_copy_op = {},
|
| 385 |
+
BSMemCopyOp b_smem_copy_op = {},
|
| 386 |
+
CSMemCopyLdOp c_smem_copy_ld_op = {},
|
| 387 |
+
CSMemCopyStOp c_smem_copy_st_op = {})
|
| 388 |
+
{
|
| 389 |
+
static_assert(std::is_same_v<typename fp64_tester<TA>::value_type, typename fp64_tester<TB>::value_type>);
|
| 390 |
+
static_assert(std::is_same_v<typename fp64_tester<TB>::value_type, typename fp64_tester<TC>::value_type>);
|
| 391 |
+
|
| 392 |
+
static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM
|
| 393 |
+
static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN
|
| 394 |
+
static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK
|
| 395 |
+
|
| 396 |
+
static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM
|
| 397 |
+
static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN
|
| 398 |
+
static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK
|
| 399 |
+
|
| 400 |
+
static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout));
|
| 401 |
+
static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout));
|
| 402 |
+
static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout));
|
| 403 |
+
|
| 404 |
+
#if 0
|
| 405 |
+
print(" "); print("gmem: "); print(gmem_layout); print("\n");
|
| 406 |
+
print(" "); print("smem: "); print(smem_layout); print("\n");
|
| 407 |
+
print(" "); print("threads: "); print(ThreadBlockSize); print("\n");
|
| 408 |
+
#endif
|
| 409 |
+
|
| 410 |
+
const auto alpha = static_cast<TC>(1.1);
|
| 411 |
+
const auto beta = static_cast<TC>(1.2);
|
| 412 |
+
|
| 413 |
+
// Generate inputs
|
| 414 |
+
auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs<TA, TB, TC>(gmem_a_layout, gmem_b_layout, gmem_c_layout);
|
| 415 |
+
|
| 416 |
+
thrust::device_vector<TA> d_a(h_a);
|
| 417 |
+
thrust::device_vector<TB> d_b(h_b);
|
| 418 |
+
thrust::device_vector<TC> d_c(h_c);
|
| 419 |
+
thrust::device_vector<TC> d_c_out(h_c_out.size(), TC(float(-1)));
|
| 420 |
+
|
| 421 |
+
constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
|
| 422 |
+
|
| 423 |
+
const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) +
|
| 424 |
+
round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) +
|
| 425 |
+
sizeof(TC) * h_c.size();
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
auto kernel = cooperative_gemm_kernel<
|
| 429 |
+
ThreadBlockSize, CopyMaxVecBits,
|
| 430 |
+
GMemALayout, GMemBLayout, GMemCLayout,
|
| 431 |
+
SMemALayout, SMemBLayout, SMemCLayout,
|
| 432 |
+
TA, TB, TC, decltype(alpha), decltype(beta),
|
| 433 |
+
TiledMma,
|
| 434 |
+
ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform,
|
| 435 |
+
ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp
|
| 436 |
+
>;
|
| 437 |
+
|
| 438 |
+
ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(shared_memory_size)), 0);
|
| 439 |
+
|
| 440 |
+
kernel<<<1, ThreadBlockSize, shared_memory_size>>>(
|
| 441 |
+
gmem_a_layout,
|
| 442 |
+
gmem_b_layout,
|
| 443 |
+
gmem_c_layout,
|
| 444 |
+
smem_a_layout,
|
| 445 |
+
smem_b_layout,
|
| 446 |
+
smem_c_layout,
|
| 447 |
+
thrust::raw_pointer_cast(d_a.data()),
|
| 448 |
+
thrust::raw_pointer_cast(d_b.data()),
|
| 449 |
+
thrust::raw_pointer_cast(d_c.data()),
|
| 450 |
+
thrust::raw_pointer_cast(d_c_out.data()),
|
| 451 |
+
alpha,
|
| 452 |
+
beta,
|
| 453 |
+
tiled_mma,
|
| 454 |
+
a_load_transform,
|
| 455 |
+
b_load_transform,
|
| 456 |
+
c_load_transform,
|
| 457 |
+
c_store_transform,
|
| 458 |
+
a_smem_copy_op,
|
| 459 |
+
b_smem_copy_op,
|
| 460 |
+
c_smem_copy_ld_op,
|
| 461 |
+
c_smem_copy_st_op
|
| 462 |
+
);
|
| 463 |
+
|
| 464 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 465 |
+
if (result != cudaSuccess) {
|
| 466 |
+
cudaError_t error = cudaGetLastError();
|
| 467 |
+
FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n";
|
| 468 |
+
}
|
| 469 |
+
|
| 470 |
+
// Reference gemm
|
| 471 |
+
auto h_c_ref = host_reference_gemm(alpha,
|
| 472 |
+
make_tensor(h_a.data(), gmem_a_layout),
|
| 473 |
+
make_tensor(h_b.data(), gmem_b_layout),
|
| 474 |
+
beta,
|
| 475 |
+
make_tensor(h_c.data(), gmem_c_layout),
|
| 476 |
+
a_load_transform,
|
| 477 |
+
b_load_transform,
|
| 478 |
+
c_load_transform,
|
| 479 |
+
c_store_transform);
|
| 480 |
+
|
| 481 |
+
// Copy result data
|
| 482 |
+
h_c_out = d_c_out;
|
| 483 |
+
|
| 484 |
+
// Verify correctness
|
| 485 |
+
verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout),
|
| 486 |
+
make_tensor(h_c_ref.data(), gmem_c_layout));
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
template<uint32_t ThreadBlockSize,
|
| 490 |
+
uint32_t CopyMaxVecBits,
|
| 491 |
+
class TA,
|
| 492 |
+
class TB,
|
| 493 |
+
class TC,
|
| 494 |
+
class GMemALayout, // logical shape (M, K)
|
| 495 |
+
class GMemBLayout, // logical shape (N, K)
|
| 496 |
+
class GMemCLayout, // logical shape (M, N)
|
| 497 |
+
class SMemALayout, // logical shape (M, K)
|
| 498 |
+
class SMemBLayout, // logical shape (N, K)
|
| 499 |
+
class TiledMma,
|
| 500 |
+
class ALoadTransform = cute::identity,
|
| 501 |
+
class BLoadTransform = cute::identity,
|
| 502 |
+
class CLoadTransform = cute::identity,
|
| 503 |
+
class CStoreTransform = cute::identity,
|
| 504 |
+
class ASMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
|
| 505 |
+
class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>>
|
| 506 |
+
void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout,
|
| 507 |
+
GMemBLayout gmem_b_layout,
|
| 508 |
+
GMemCLayout gmem_c_layout,
|
| 509 |
+
SMemALayout smem_a_layout,
|
| 510 |
+
SMemBLayout smem_b_layout,
|
| 511 |
+
TiledMma tiled_mma,
|
| 512 |
+
ALoadTransform a_load_transform = {},
|
| 513 |
+
BLoadTransform b_load_transform = {},
|
| 514 |
+
CLoadTransform c_load_transform = {},
|
| 515 |
+
CStoreTransform c_store_transform = {},
|
| 516 |
+
ASMemCopyOp a_smem_copy_op = {},
|
| 517 |
+
BSMemCopyOp b_smem_copy_op = {})
|
| 518 |
+
{
|
| 519 |
+
static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM
|
| 520 |
+
static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN
|
| 521 |
+
static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK
|
| 522 |
+
|
| 523 |
+
static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK
|
| 524 |
+
|
| 525 |
+
static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout));
|
| 526 |
+
static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout));
|
| 527 |
+
|
| 528 |
+
#if 0
|
| 529 |
+
print(" "); print("gmem: "); print(gmem_layout); print("\n");
|
| 530 |
+
print(" "); print("smem: "); print(smem_layout); print("\n");
|
| 531 |
+
print(" "); print("threads: "); print(ThreadBlockSize); print("\n");
|
| 532 |
+
#endif
|
| 533 |
+
|
| 534 |
+
const auto alpha = static_cast<TC>(1.0);
|
| 535 |
+
const auto beta = static_cast<TC>(1.0);
|
| 536 |
+
|
| 537 |
+
// Generate inputs
|
| 538 |
+
auto [h_a, h_b, h_c, h_c_out] =
|
| 539 |
+
host_generate_gemm_inputs<TA, TB, TC>(gmem_a_layout, gmem_b_layout, gmem_c_layout);
|
| 540 |
+
|
| 541 |
+
thrust::device_vector<TA> d_a(h_a);
|
| 542 |
+
thrust::device_vector<TB> d_b(h_b);
|
| 543 |
+
thrust::device_vector<TC> d_c(h_c);
|
| 544 |
+
thrust::device_vector<TC> d_c_out(h_c_out.size(), static_cast<TC>(-1));
|
| 545 |
+
|
| 546 |
+
constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
|
| 547 |
+
|
| 548 |
+
const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) +
|
| 549 |
+
round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes);
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
auto kernel = cooperative_gemm_kernel_rmem_c<
|
| 553 |
+
ThreadBlockSize, CopyMaxVecBits,
|
| 554 |
+
GMemALayout, GMemBLayout, GMemCLayout,
|
| 555 |
+
SMemALayout, SMemBLayout,
|
| 556 |
+
TA, TB, TC,
|
| 557 |
+
TiledMma,
|
| 558 |
+
ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform,
|
| 559 |
+
ASMemCopyOp, BSMemCopyOp
|
| 560 |
+
>;
|
| 561 |
+
|
| 562 |
+
ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(shared_memory_size)), 0);
|
| 563 |
+
|
| 564 |
+
kernel<<<1, ThreadBlockSize, shared_memory_size>>>(
|
| 565 |
+
gmem_a_layout,
|
| 566 |
+
gmem_b_layout,
|
| 567 |
+
gmem_c_layout,
|
| 568 |
+
smem_a_layout,
|
| 569 |
+
smem_b_layout,
|
| 570 |
+
thrust::raw_pointer_cast(d_a.data()),
|
| 571 |
+
thrust::raw_pointer_cast(d_b.data()),
|
| 572 |
+
thrust::raw_pointer_cast(d_c.data()),
|
| 573 |
+
thrust::raw_pointer_cast(d_c_out.data()),
|
| 574 |
+
tiled_mma,
|
| 575 |
+
a_load_transform, b_load_transform, c_load_transform, c_store_transform,
|
| 576 |
+
a_smem_copy_op, b_smem_copy_op
|
| 577 |
+
);
|
| 578 |
+
|
| 579 |
+
cudaError_t result = cudaDeviceSynchronize();
|
| 580 |
+
if (result != cudaSuccess) {
|
| 581 |
+
cudaError_t error = cudaGetLastError();
|
| 582 |
+
FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n";
|
| 583 |
+
}
|
| 584 |
+
|
| 585 |
+
// Copy result data
|
| 586 |
+
h_c_out = d_c_out;
|
| 587 |
+
|
| 588 |
+
// Reference gemm
|
| 589 |
+
auto h_c_ref = host_reference_gemm(alpha,
|
| 590 |
+
make_tensor(h_a.data(), gmem_a_layout),
|
| 591 |
+
make_tensor(h_b.data(), gmem_b_layout),
|
| 592 |
+
beta,
|
| 593 |
+
make_tensor(h_c.data(), gmem_c_layout),
|
| 594 |
+
a_load_transform,
|
| 595 |
+
b_load_transform,
|
| 596 |
+
c_load_transform,
|
| 597 |
+
c_store_transform);
|
| 598 |
+
|
| 599 |
+
// Verify correctness
|
| 600 |
+
verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout),
|
| 601 |
+
make_tensor(h_c_ref.data(), gmem_c_layout));
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
template<uint32_t ThreadBlockSize,
|
| 605 |
+
uint32_t CopyMaxVecBits,
|
| 606 |
+
class TA,
|
| 607 |
+
class TB,
|
| 608 |
+
class TC,
|
| 609 |
+
class ShapeMNK,
|
| 610 |
+
class TiledMma,
|
| 611 |
+
class ... Ops>
|
| 612 |
+
void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk,
|
| 613 |
+
TiledMma tiled_mma,
|
| 614 |
+
Ops ... ops)
|
| 615 |
+
{
|
| 616 |
+
auto a_layout = make_layout(select<0, 2>(shape_mnk));
|
| 617 |
+
auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
|
| 618 |
+
auto c_layout = make_layout(select<0, 1>(shape_mnk));
|
| 619 |
+
|
| 620 |
+
test_cooperative_gemm<ThreadBlockSize,
|
| 621 |
+
CopyMaxVecBits,
|
| 622 |
+
TA, TB, TC>
|
| 623 |
+
(a_layout,
|
| 624 |
+
b_layout,
|
| 625 |
+
c_layout,
|
| 626 |
+
a_layout,
|
| 627 |
+
b_layout,
|
| 628 |
+
c_layout,
|
| 629 |
+
tiled_mma,
|
| 630 |
+
ops...);
|
| 631 |
+
}
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
template<uint32_t ThreadBlockSize,
|
| 635 |
+
uint32_t CopyMaxVecBits,
|
| 636 |
+
class TA,
|
| 637 |
+
class TB,
|
| 638 |
+
class TC,
|
| 639 |
+
class SMemAtomLayoutA,
|
| 640 |
+
class SMemAtomLayoutB,
|
| 641 |
+
class SMemAtomLayoutC,
|
| 642 |
+
class ShapeMNK,
|
| 643 |
+
class TiledMma,
|
| 644 |
+
class ... Ops>
|
| 645 |
+
std::enable_if_t<std::conjunction_v<cute::is_layout<SMemAtomLayoutA>,
|
| 646 |
+
cute::is_layout<SMemAtomLayoutB>,
|
| 647 |
+
cute::is_layout<SMemAtomLayoutC>>>
|
| 648 |
+
test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a,
|
| 649 |
+
SMemAtomLayoutB smem_atom_layout_b,
|
| 650 |
+
SMemAtomLayoutC smem_atom_layout_c,
|
| 651 |
+
ShapeMNK shape_mnk,
|
| 652 |
+
TiledMma tiled_mma,
|
| 653 |
+
Ops&& ... ops)
|
| 654 |
+
{
|
| 655 |
+
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk));
|
| 656 |
+
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
|
| 657 |
+
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk));
|
| 658 |
+
|
| 659 |
+
auto smem_a_layout = tile_to_shape(
|
| 660 |
+
smem_atom_layout_a,
|
| 661 |
+
make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
|
| 662 |
+
|
| 663 |
+
auto smem_b_layout = tile_to_shape(
|
| 664 |
+
smem_atom_layout_b,
|
| 665 |
+
make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
|
| 666 |
+
|
| 667 |
+
auto smem_c_layout = tile_to_shape(
|
| 668 |
+
smem_atom_layout_c,
|
| 669 |
+
make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout)));
|
| 670 |
+
|
| 671 |
+
test_cooperative_gemm<ThreadBlockSize,
|
| 672 |
+
CopyMaxVecBits,
|
| 673 |
+
TA, TB, TC>
|
| 674 |
+
(gmem_a_layout,
|
| 675 |
+
gmem_b_layout,
|
| 676 |
+
gmem_c_layout,
|
| 677 |
+
smem_a_layout,
|
| 678 |
+
smem_b_layout,
|
| 679 |
+
smem_c_layout,
|
| 680 |
+
tiled_mma,
|
| 681 |
+
ops...);
|
| 682 |
+
}
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
template<uint32_t ThreadBlockSize,
|
| 686 |
+
uint32_t CopyMaxVecBits,
|
| 687 |
+
class TA,
|
| 688 |
+
class TB,
|
| 689 |
+
class TC,
|
| 690 |
+
class ShapeMNK,
|
| 691 |
+
class TiledMma,
|
| 692 |
+
class ... Ops>
|
| 693 |
+
void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk,
|
| 694 |
+
TiledMma tiled_mma,
|
| 695 |
+
Ops ... ops)
|
| 696 |
+
{
|
| 697 |
+
auto a_layout = make_layout(select<0, 2>(shape_mnk));
|
| 698 |
+
auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
|
| 699 |
+
auto c_layout = make_layout(select<0, 1>(shape_mnk));
|
| 700 |
+
|
| 701 |
+
|
| 702 |
+
test_cooperative_gemm_rmem_c<ThreadBlockSize,
|
| 703 |
+
CopyMaxVecBits,
|
| 704 |
+
TA, TB,TC>
|
| 705 |
+
(a_layout,
|
| 706 |
+
b_layout,
|
| 707 |
+
c_layout,
|
| 708 |
+
a_layout,
|
| 709 |
+
b_layout,
|
| 710 |
+
tiled_mma,
|
| 711 |
+
ops...);
|
| 712 |
+
}
|
| 713 |
+
|
| 714 |
+
template<uint32_t ThreadBlockSize,
|
| 715 |
+
uint32_t CopyMaxVecBits,
|
| 716 |
+
class TA,
|
| 717 |
+
class TB,
|
| 718 |
+
class TC,
|
| 719 |
+
class SMemAtomLayoutA,
|
| 720 |
+
class SMemAtomLayoutB,
|
| 721 |
+
class ShapeMNK,
|
| 722 |
+
class TiledMma,
|
| 723 |
+
class ... Ops>
|
| 724 |
+
std::enable_if_t<std::conjunction_v<cute::is_layout<SMemAtomLayoutA>,
|
| 725 |
+
cute::is_layout<SMemAtomLayoutB>>>
|
| 726 |
+
test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a,
|
| 727 |
+
SMemAtomLayoutB smem_atom_layout_b,
|
| 728 |
+
ShapeMNK shape_mnk,
|
| 729 |
+
TiledMma tiled_mma,
|
| 730 |
+
Ops ... ops)
|
| 731 |
+
{
|
| 732 |
+
auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk));
|
| 733 |
+
auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
|
| 734 |
+
auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk));
|
| 735 |
+
|
| 736 |
+
auto smem_a_layout = tile_to_shape(
|
| 737 |
+
smem_atom_layout_a,
|
| 738 |
+
make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
|
| 739 |
+
|
| 740 |
+
auto smem_b_layout = tile_to_shape(
|
| 741 |
+
smem_atom_layout_b,
|
| 742 |
+
make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
|
| 743 |
+
|
| 744 |
+
test_cooperative_gemm_rmem_c<ThreadBlockSize, CopyMaxVecBits,
|
| 745 |
+
TA, TB, TC>
|
| 746 |
+
(gmem_a_layout,
|
| 747 |
+
gmem_b_layout,
|
| 748 |
+
gmem_c_layout,
|
| 749 |
+
smem_a_layout,
|
| 750 |
+
smem_b_layout,
|
| 751 |
+
tiled_mma,
|
| 752 |
+
ops...);
|
| 753 |
+
}
|
| 754 |
+
|
| 755 |
+
template<uint32_t ThreadBlockSize,
|
| 756 |
+
typename T,
|
| 757 |
+
class ... Args>
|
| 758 |
+
void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args)
|
| 759 |
+
{
|
| 760 |
+
test_cooperative_gemm_col_major_layout_rmem_c<ThreadBlockSize,
|
| 761 |
+
cute::sizeof_bits_v<T>,
|
| 762 |
+
T, T, T>
|
| 763 |
+
(static_cast<Args&&>(args)...);
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
template<uint32_t ThreadBlockSize,
|
| 767 |
+
class T,
|
| 768 |
+
class ... Args>
|
| 769 |
+
void test_cooperative_gemm_col_major_layout(Args&& ... args)
|
| 770 |
+
{
|
| 771 |
+
test_cooperative_gemm_col_major_layout<ThreadBlockSize,
|
| 772 |
+
cute::sizeof_bits_v<T>,
|
| 773 |
+
T, T, T>
|
| 774 |
+
(static_cast<Args&&>(args)...);
|
| 775 |
+
}
|
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/***************************************************************************************************
|
| 2 |
+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 3 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
| 4 |
+
*
|
| 5 |
+
* Redistribution and use in source and binary forms, with or without
|
| 6 |
+
* modification, are permitted provided that the following conditions are met:
|
| 7 |
+
*
|
| 8 |
+
* 1. Redistributions of source code must retain the above copyright notice, this
|
| 9 |
+
* list of conditions and the following disclaimer.
|
| 10 |
+
*
|
| 11 |
+
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
| 12 |
+
* this list of conditions and the following disclaimer in the documentation
|
| 13 |
+
* and/or other materials provided with the distribution.
|
| 14 |
+
*
|
| 15 |
+
* 3. Neither the name of the copyright holder nor the names of its
|
| 16 |
+
* contributors may be used to endorse or promote products derived from
|
| 17 |
+
* this software without specific prior written permission.
|
| 18 |
+
*
|
| 19 |
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
| 20 |
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
| 21 |
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 22 |
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 23 |
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 24 |
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 25 |
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 26 |
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 27 |
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 28 |
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 29 |
+
*
|
| 30 |
+
**************************************************************************************************/
|
| 31 |
+
|
| 32 |
+
#pragma once
|
| 33 |
+
|
| 34 |
+
#include "cutlass_unit_test.h"
|
| 35 |
+
|
| 36 |
+
#include <iostream>
|
| 37 |
+
#include <cstdint>
|
| 38 |
+
|
| 39 |
+
#include <thrust/host_vector.h>
|
| 40 |
+
#include <thrust/device_vector.h>
|
| 41 |
+
|
| 42 |
+
#include <cute/tensor.hpp>
|
| 43 |
+
|
| 44 |
+
namespace cutlass::test {
|
| 45 |
+
|
| 46 |
+
template <class ElementType, class SmemLayout>
|
| 47 |
+
struct SharedStorage
|
| 48 |
+
{
|
| 49 |
+
cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
|
| 50 |
+
alignas(16) cute::uint64_t tma_load_mbar[1];
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
#if CUDA_12_0_SM90_FEATURES_SUPPORTED
|
| 54 |
+
|
| 55 |
+
template <class T, class TiledCopy, class CTA_Tiler, class GmemLayout, class SmemLayout>
|
| 56 |
+
__global__ void
|
| 57 |
+
tma_test_device_cute(T const* g_in, T* g_out,
|
| 58 |
+
CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler,
|
| 59 |
+
GmemLayout gmem_layout, SmemLayout smem_layout)
|
| 60 |
+
{
|
| 61 |
+
using namespace cute;
|
| 62 |
+
CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout)));
|
| 63 |
+
|
| 64 |
+
// Use Shared Storage structure to allocate and distribute aligned SMEM addresses
|
| 65 |
+
extern __shared__ char shared_memory[];
|
| 66 |
+
using SharedStorage = SharedStorage<T, SmemLayout>;
|
| 67 |
+
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
|
| 68 |
+
|
| 69 |
+
// Construct SMEM tensor
|
| 70 |
+
Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
|
| 71 |
+
// Shared memory barriers use 64bits in SMEM for synchronization
|
| 72 |
+
uint64_t* tma_load_mbar = shared_storage.tma_load_mbar;
|
| 73 |
+
|
| 74 |
+
// TMA requires special handling of strides to deal with coord codomain mapping
|
| 75 |
+
// Represent the full tensors -- get these from TMA
|
| 76 |
+
Tensor mA = tma.get_tma_tensor(shape(gmem_layout));
|
| 77 |
+
Tensor mB = make_tensor(make_gmem_ptr<T>(g_out), gmem_layout);
|
| 78 |
+
|
| 79 |
+
constexpr int R = rank_v<CTA_Tiler>;
|
| 80 |
+
Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
| 81 |
+
Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
|
| 82 |
+
|
| 83 |
+
//
|
| 84 |
+
// Prepare the TMA_LOAD
|
| 85 |
+
//
|
| 86 |
+
|
| 87 |
+
auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice
|
| 88 |
+
Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
|
| 89 |
+
Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N)
|
| 90 |
+
|
| 91 |
+
#if 0
|
| 92 |
+
if (thread0()) {
|
| 93 |
+
print(tma);
|
| 94 |
+
print("TILE : "); print(cta_tiler); print("\n");
|
| 95 |
+
print(" mA : "); print( mA); print("\n");
|
| 96 |
+
print(" mB : "); print( mB); print("\n");
|
| 97 |
+
print(" gA : "); print( gA); print("\n");
|
| 98 |
+
print(" gB : "); print( gB); print("\n");
|
| 99 |
+
print(" sA : "); print( sA); print("\n");
|
| 100 |
+
print("tAgA_x: "); print(tAgA_x); print("\n");
|
| 101 |
+
print("tAsA_x: "); print(tAsA_x); print("\n");
|
| 102 |
+
}
|
| 103 |
+
#endif
|
| 104 |
+
|
| 105 |
+
//
|
| 106 |
+
// Perform the TMA_LOAD
|
| 107 |
+
//
|
| 108 |
+
|
| 109 |
+
// INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles
|
| 110 |
+
Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST)
|
| 111 |
+
Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST)
|
| 112 |
+
static_assert(size<1>(tAsA) == 1);
|
| 113 |
+
|
| 114 |
+
// OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output
|
| 115 |
+
Tensor tBgB = group_modes<0,R>(group_modes<R,rank(gB)>(gB)); // (CTA_TILE, REST)
|
| 116 |
+
|
| 117 |
+
#if 0
|
| 118 |
+
if (thread0()) {
|
| 119 |
+
print("tAgA : "); print(tAgA); print("\n");
|
| 120 |
+
print("tAsA : "); print(tAsA); print("\n");
|
| 121 |
+
print("tBgB : "); print(tBgB); print("\n");
|
| 122 |
+
}
|
| 123 |
+
#endif
|
| 124 |
+
|
| 125 |
+
// Test L2 prefetch
|
| 126 |
+
if (threadIdx.x == 0) {
|
| 127 |
+
prefetch(tma, tAgA);
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
// Loop over the TMA stages, using smem as our buffer
|
| 131 |
+
for (int stage = 0; stage < size<1>(tAgA); ++stage)
|
| 132 |
+
{
|
| 133 |
+
// Set the bytes transferred in this TMA transaction (may involve multiple issues)
|
| 134 |
+
constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA)));
|
| 135 |
+
|
| 136 |
+
if (threadIdx.x == 0)
|
| 137 |
+
{
|
| 138 |
+
/// Initialize shared memory barrier
|
| 139 |
+
tma_load_mbar[0] = 0;
|
| 140 |
+
cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/);
|
| 141 |
+
cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes);
|
| 142 |
+
|
| 143 |
+
copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0));
|
| 144 |
+
}
|
| 145 |
+
__syncthreads();
|
| 146 |
+
|
| 147 |
+
/// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value
|
| 148 |
+
constexpr int kPhaseBit = 0;
|
| 149 |
+
cute::wait_barrier(tma_load_mbar[0], kPhaseBit);
|
| 150 |
+
|
| 151 |
+
//
|
| 152 |
+
// Write out trivially smem -> gmem
|
| 153 |
+
//
|
| 154 |
+
|
| 155 |
+
// Subbyte elements could cause race conditions, so be even more conservative
|
| 156 |
+
if (thread0()) {
|
| 157 |
+
copy(sA, tBgB(_,stage));
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
__syncthreads();
|
| 161 |
+
}
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
template <class T, class TmaType = T, class CopyOp, class GMEM_Layout, class SMEM_Layout, class CTA_Tile>
|
| 165 |
+
auto
|
| 166 |
+
test_tma_load(CopyOp const& copy_op,
|
| 167 |
+
GMEM_Layout const& gmem_layout,
|
| 168 |
+
SMEM_Layout const& smem_layout,
|
| 169 |
+
CTA_Tile const& cta_tile)
|
| 170 |
+
{
|
| 171 |
+
using namespace cute;
|
| 172 |
+
|
| 173 |
+
// Allocate and initialize host test data
|
| 174 |
+
size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
|
| 175 |
+
thrust::host_vector<uint8_t> h_in(N);
|
| 176 |
+
for (size_t i = 0; i < h_in.size(); ++i) {
|
| 177 |
+
h_in[i] = uint8_t(i % 13);
|
| 178 |
+
}
|
| 179 |
+
Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
|
| 180 |
+
|
| 181 |
+
// Allocate and initialize device test data
|
| 182 |
+
thrust::device_vector<uint8_t> d_in = h_in;
|
| 183 |
+
thrust::device_vector<uint8_t> d_out(h_in.size(), uint8_t(-1)); // overflow uint
|
| 184 |
+
|
| 185 |
+
// Create TMA for this device Tensor
|
| 186 |
+
Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout);
|
| 187 |
+
auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
|
| 188 |
+
//print(tma);
|
| 189 |
+
|
| 190 |
+
// Launch
|
| 191 |
+
int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
|
| 192 |
+
tma_test_device_cute<<<1, 128, smem_size>>>(
|
| 193 |
+
reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
|
| 194 |
+
reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
|
| 195 |
+
tma, cta_tile,
|
| 196 |
+
gmem_layout,
|
| 197 |
+
smem_layout);
|
| 198 |
+
|
| 199 |
+
// Copy results back to host
|
| 200 |
+
thrust::host_vector<uint8_t> h_out = d_out;
|
| 201 |
+
Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
|
| 202 |
+
|
| 203 |
+
// Validate the results. Print only the first 3 errors.
|
| 204 |
+
int count = 3;
|
| 205 |
+
for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) {
|
| 206 |
+
EXPECT_EQ(hA_in(i), hA_out(i));
|
| 207 |
+
if (hA_in(i) != hA_out(i)) {
|
| 208 |
+
--count;
|
| 209 |
+
}
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
return tma;
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
#endif
|
| 216 |
+
|
| 217 |
+
} // end namespace cutlass::test
|