Kernels
danieldk HF Staff commited on
Commit
06534f8
·
verified ·
1 Parent(s): d7f7ab1

Build uploaded using `kernels` (batch 5/10).

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py +173 -0
  3. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py +142 -0
  4. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py +319 -0
  5. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py +180 -0
  6. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py +44 -0
  7. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py +235 -0
  8. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py +134 -0
  9. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py +128 -0
  10. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py +146 -0
  11. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py +104 -0
  12. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py +103 -0
  13. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py +71 -0
  14. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py +112 -0
  15. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py +75 -0
  16. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py +103 -0
  17. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py +98 -0
  18. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py +423 -0
  19. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py +44 -0
  20. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py +260 -0
  21. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py +57 -0
  22. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py +284 -0
  23. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py +254 -0
  24. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py +354 -0
  25. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py +69 -0
  26. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py +75 -0
  27. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py +95 -0
  28. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py +92 -0
  29. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py +213 -0
  30. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py +80 -0
  31. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py +87 -0
  32. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py +96 -0
  33. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py +59 -0
  34. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h +102 -0
  35. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h +907 -0
  36. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h +927 -0
  37. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h +818 -0
  38. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h +666 -0
  39. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h +622 -0
  40. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h +734 -0
  41. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h +643 -0
  42. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h +293 -0
  43. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h +716 -0
  44. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h +732 -0
  45. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h +473 -0
  46. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp +1385 -0
  47. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp +768 -0
  48. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp +158 -0
  49. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp +775 -0
  50. build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp +217 -0
.gitattributes CHANGED
@@ -14,3 +14,4 @@ build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf
14
  build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
15
  build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
16
  build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
 
 
14
  build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
15
  build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
16
  build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
17
+ build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unit test for store nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTLayout(EVTTestCaseBase):
51
+
52
+ def test_permute_1(self):
53
+ """
54
+ Returning a tensor with shape [m, n]
55
+ """
56
+ def evt_permute(accum, alpha, C):
57
+ F = alpha * accum
58
+ F_permute = permute(F, indices=(0, 2, 1))
59
+ D_permute = F_permute + permute(C, indices=(0, 2, 1))
60
+ D = permute(D_permute, indices=(0, 2, 1))
61
+ return D, F
62
+
63
+ for m, n, k, l in self.get_problem_sizes(8):
64
+ example_inputs = {
65
+ "accum": self.fake_tensor(self.element, (l, m, n)),
66
+ "alpha": 0.5,
67
+ "C": self.fake_tensor(self.element, (l, m, n)),
68
+ "F": self.fake_tensor(self.element, (l, m, n)),
69
+ "D": self.fake_tensor(self.element, (l, m, n)),
70
+ }
71
+
72
+ launcher = EVTTestBed(self.element, evt_permute, example_inputs)
73
+ input_keys = ["C", "alpha"]
74
+ result_keys = ["D", "F"]
75
+ launcher.verify((m, n, k), input_keys, result_keys, l)
76
+
77
+ @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
78
+ def test_permute_2(self):
79
+ """
80
+ Returning a tensor with shape [m, n]
81
+ """
82
+ def evt_permute(accum, alpha, C):
83
+ F = alpha * accum
84
+ F_permute = permute(F, indices=(0, 2, 1))
85
+ D = F_permute + C
86
+ return D, F
87
+
88
+ for m, n, k, l in self.get_problem_sizes(8):
89
+ example_inputs = {
90
+ "accum": self.fake_tensor(self.element, (l, m, n)),
91
+ "alpha": 0.5,
92
+ "C": self.fake_tensor(self.element, (l, n, m)),
93
+ "F": self.fake_tensor(self.element, (l, m, n)),
94
+ "D": self.fake_tensor(self.element, (l, n, m)),
95
+ }
96
+
97
+ launcher = EVTTestBed(self.element, evt_permute, example_inputs)
98
+ input_keys = ["C", "alpha"]
99
+ result_keys = ["D", "F"]
100
+ launcher.verify((m, n, k), input_keys, result_keys, l)
101
+
102
+ @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only")
103
+ def test_permute_3(self):
104
+ """
105
+ Returning a tensor with shape [m, n]
106
+ """
107
+ def evt_permute(accum, alpha, C):
108
+ F = alpha * accum
109
+ F_permute = permute(F, indices=(1, 0, 2))
110
+ D = F_permute + C
111
+ return D, F
112
+
113
+ for m, n, k, l in self.get_problem_sizes(8):
114
+ example_inputs = {
115
+ "accum": self.fake_tensor(self.element, (l, m, n)),
116
+ "alpha": 0.5,
117
+ "C": self.fake_tensor(self.element, (m, l, n)),
118
+ "F": self.fake_tensor(self.element, (l, m, n)),
119
+ "D": self.fake_tensor(self.element, (m, l, n)),
120
+ }
121
+
122
+ launcher = EVTTestBed(self.element, evt_permute, example_inputs)
123
+ input_keys = ["C", "alpha"]
124
+ result_keys = ["D", "F"]
125
+ launcher.verify((m, n, k), input_keys, result_keys, l)
126
+
127
+ def test_reshape(self):
128
+ """
129
+ Test reshape
130
+ """
131
+ def evt_reshape(accum, alpha, TensorE):
132
+ F = alpha * accum
133
+ E_reshape = reshape(TensorE, new_shape=(512, 1))
134
+ D = F + E_reshape
135
+ return D
136
+
137
+ example_inputs = {
138
+ "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
139
+ "alpha": 0.5,
140
+ "TensorE": self.fake_tensor(self.element, (16, 32)),
141
+ "D": self.fake_tensor(self.element, (self.l, self.m, self.n)),
142
+ }
143
+
144
+ launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
145
+ input_keys = ["alpha", "TensorE"]
146
+ result_keys = ["D"]
147
+ launcher.verify(self.problem_size, input_keys, result_keys, self.l)
148
+
149
+ def test_reshape2(self):
150
+ """
151
+ Test reshape
152
+ """
153
+ def evt_reshape(accum, alpha, TensorE):
154
+ F = alpha * accum
155
+ F_reshape = reshape(F, new_shape=(2, 3, 512, 256))
156
+ D = F_reshape + TensorE
157
+ return D
158
+
159
+ example_inputs = {
160
+ "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)),
161
+ "alpha": 0.5,
162
+ "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)),
163
+ "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)),
164
+ }
165
+
166
+ launcher = EVTTestBed(self.element, evt_reshape, example_inputs)
167
+ input_keys = ["alpha", "TensorE"]
168
+ result_keys = ["D"]
169
+ launcher.verify(self.problem_size, input_keys, result_keys, self.l)
170
+
171
+
172
+ if __name__ == '__main__':
173
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unit test for load nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTLoad(EVTTestCaseBase):
51
+
52
+ def test_tensor_load(self):
53
+ """
54
+ Load extra tensor with shape [m, n]
55
+ """
56
+ def evt_tensor_load(accum, C, aux, aux_batch):
57
+ D = accum + C + aux + aux_batch
58
+ return D
59
+
60
+ for m, n, k, l in self.get_problem_sizes(8):
61
+ example_inputs = {
62
+ "accum": self.fake_tensor(self.element, (l, m, n)),
63
+ "C": self.fake_tensor(self.element, (l, m, n)),
64
+ "aux": self.fake_tensor(self.element, (m, n)),
65
+ "aux_batch": self.fake_tensor(np.float32, (l, m, n)),
66
+ "D": self.fake_tensor(self.element, (l, m, n)),
67
+ }
68
+
69
+ launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs)
70
+ input_keys = ["C", "aux", "aux_batch"]
71
+ result_keys = ["D"]
72
+ launcher.verify((m, n, k), input_keys, result_keys, l)
73
+
74
+ def test_row_broadcast(self):
75
+ """
76
+ Load extra tensor with shape [1, n]
77
+ """
78
+ def evt_row_broadcast(accum, C, bias, bias_batch):
79
+ D = accum + C + bias + bias_batch
80
+ return D
81
+
82
+ for m, n, k, l in self.get_problem_sizes(8):
83
+ example_inputs = {
84
+ "accum": self.fake_tensor(self.element, (l, m, n)),
85
+ "C": self.fake_tensor(self.element, (l, m, n)),
86
+ "bias": self.fake_tensor(self.element, (n,)),
87
+ "bias_batch": self.fake_tensor(np.float32, (l, 1, n)),
88
+ "D": self.fake_tensor(self.element, (l, m, n)),
89
+ }
90
+
91
+ launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs)
92
+ input_keys = ["C", "bias", "bias_batch"]
93
+ result_keys = ["D"]
94
+ launcher.verify((m, n, k), input_keys, result_keys, l)
95
+
96
+ def test_column_broadcast(self):
97
+ """
98
+ Load extra tensor with shape [m, 1]
99
+ """
100
+ def evt_column_broadcast(accum, C, bias, bias_batch):
101
+ D = accum + C + bias + bias_batch
102
+ return D
103
+
104
+ for m, n, k, l in self.get_problem_sizes(8):
105
+ example_inputs = {
106
+ "accum": self.fake_tensor(self.element, (l, m, n)),
107
+ "C": self.fake_tensor(self.element, (l, m, n)),
108
+ "bias": self.fake_tensor(self.element, (m, 1)),
109
+ "bias_batch": self.fake_tensor(np.float32, (l, m, 1)),
110
+ "D": self.fake_tensor(self.element, (l, m, n)),
111
+ }
112
+
113
+ launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs)
114
+ input_keys = ["C", "bias", "bias_batch"]
115
+ result_keys = ["D"]
116
+ launcher.verify((m, n, k), input_keys, result_keys, l)
117
+
118
+ def test_scalar_broadcast(self):
119
+ """
120
+ Load extra tensor with shape [1, 1]
121
+ """
122
+ def evt_scalar_broadcast(accum, C, alpha, alpha_batch):
123
+ D = accum + C + alpha + alpha_batch
124
+ return D
125
+
126
+ for m, n, k, l in self.get_problem_sizes(8):
127
+ example_inputs = {
128
+ "accum": self.fake_tensor(self.element, (l, m, n)),
129
+ "C": self.fake_tensor(self.element, (l, m, n)),
130
+ "alpha": 0.5,
131
+ "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)),
132
+ "D": self.fake_tensor(self.element, (l, m, n)),
133
+ }
134
+
135
+ launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs)
136
+ input_keys = ["C", "alpha", "alpha_batch"]
137
+ result_keys = ["D"]
138
+ launcher.verify((m, n, k), input_keys, result_keys, l)
139
+
140
+
141
+ if __name__ == '__main__':
142
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unittest for mixed types of nodes in SM90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+ from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK
44
+
45
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+
49
+
50
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
51
+ class TestEVTMixed(EVTTestCaseBase):
52
+
53
+ def test_same_variable_used_multiple_times(self):
54
+ """
55
+ The same variable z0 is used multiple times
56
+ """
57
+ def evt_aux_store(accum):
58
+ z0 = relu(accum)
59
+ D = z0 + z0
60
+ return z0, D
61
+
62
+ for m, n, k, l in self.get_problem_sizes(8):
63
+ example_inputs = {
64
+ "accum": self.fake_tensor(self.element, (l, m, n)),
65
+ "D": self.fake_tensor(self.element, (l, m, n)),
66
+ "z0": self.fake_tensor(self.element, (l, m, n)),
67
+ }
68
+
69
+ launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
70
+ input_keys = ["accum"]
71
+ result_keys = ["z0", "D"]
72
+ launcher.verify((m, n, k), input_keys, result_keys, l)
73
+
74
+ def test_no_lca(self):
75
+ """
76
+ The same variable z0 is used multiple times
77
+ """
78
+ def evt_no_lca(accum, bias):
79
+ E = relu(accum)
80
+ F = E + bias
81
+ tmp_2 = E + 2
82
+ D = tmp_2 + E
83
+ return D
84
+
85
+ for m, n, k, l in self.get_problem_sizes(8):
86
+ example_inputs = {
87
+ "accum": self.fake_tensor(self.element, (l, m, n)),
88
+ "D": self.fake_tensor(self.element, (l, m, n)),
89
+ "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)),
90
+ }
91
+
92
+ launcher = EVTTestBed(self.element, evt_no_lca, example_inputs)
93
+ input_keys = ["accum", "bias"]
94
+ result_keys = ["D"]
95
+ launcher.verify((m, n, k), input_keys, result_keys, l)
96
+
97
+ def test_mixed_dag(self):
98
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
99
+ F = alpha * accum + (beta * C + aux)
100
+ F_row_max = max(F, dim=[0, 1])
101
+ E = relu(F + 1) + cbias + rbias
102
+ E_col_max = max(E, dim=[0, 2])
103
+ D = E + F
104
+ return D, F, F_row_max, E_col_max
105
+
106
+ if device_cc() == 80:
107
+ alignments = [2, 4, 8]
108
+ else:
109
+ # Sm90 EVT currently only supports 128-bit alignment
110
+ alignments = [8,]
111
+ for align in alignments:
112
+ for m, n, k, l in self.get_problem_sizes(align):
113
+ example_inputs = {
114
+ "accum": self.fake_tensor(self.element, (l, m, n)),
115
+ "alpha": 1.0,
116
+ "C": self.fake_tensor(self.element, (l, m, n)),
117
+ "beta": 1.0,
118
+ "aux": self.fake_tensor(self.element, (l, m, n)),
119
+ "cbias": self.fake_tensor(self.element, (m, 1)),
120
+ "rbias": self.fake_tensor(self.element, (n,)),
121
+ "D": self.fake_tensor(self.element, (l, m, n)),
122
+ "F": self.fake_tensor(self.element, (l, m, n)),
123
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
124
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
125
+ }
126
+
127
+ launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs)
128
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
129
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
130
+ launcher.verify((m, n, k), input_keys, result_keys, l)
131
+
132
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
133
+ def test_mixed_dag_float(self):
134
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
135
+ F = alpha * accum + (beta * C + aux)
136
+ F_row_max = max(F, dim=[0, 1])
137
+ E = relu(F + 1) + cbias + rbias
138
+ E_col_max = max(E, dim=[0, 2])
139
+ D = E + F
140
+ return D, F, F_row_max, E_col_max
141
+
142
+ for align in [3, 2, 4]:
143
+ for m, n, k, l in self.get_problem_sizes(align):
144
+ example_inputs = {
145
+ "accum": self.fake_tensor(np.float32, (l, m, n)),
146
+ "alpha": 1.0,
147
+ "C": self.fake_tensor(np.float32, (l, m, n)),
148
+ "beta": 1.0,
149
+ "aux": self.fake_tensor(np.float32, (l, m, n)),
150
+ "cbias": self.fake_tensor(np.float32, (m, 1)),
151
+ "rbias": self.fake_tensor(np.float32, (n,)),
152
+ "D": self.fake_tensor(np.float32, (l, m, n)),
153
+ "F": self.fake_tensor(np.float32, (l, m, n)),
154
+ "F_row_max": self.fake_tensor(np.float32, (n,)),
155
+ "E_col_max": self.fake_tensor(np.float32, (m, 1))
156
+ }
157
+ launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs)
158
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
159
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
160
+ launcher.verify((m, n, k), input_keys, result_keys, l)
161
+
162
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
163
+ def test_mixed_dag_stage2(self):
164
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
165
+ F = alpha * accum + (beta * C + aux)
166
+ F_row_max = max(F, dim=[0, 1])
167
+ E = relu(F + 1) + cbias + rbias
168
+ E_col_max = max(E, dim=[0, 2])
169
+ D = E + F
170
+ return D, F, F_row_max, E_col_max
171
+
172
+ for m, n, k, l in self.get_problem_sizes(8):
173
+ example_inputs = {
174
+ "accum": self.fake_tensor(self.element, (l, m, n)),
175
+ "alpha": 1.0,
176
+ "C": self.fake_tensor(self.element, (l, m, n)),
177
+ "beta": 1.0,
178
+ "aux": self.fake_tensor(self.element, (l, m, n)),
179
+ "cbias": self.fake_tensor(self.element, (m, 1)),
180
+ "rbias": self.fake_tensor(self.element, (n,)),
181
+ "D": self.fake_tensor(self.element, (l, m, n)),
182
+ "F": self.fake_tensor(self.element, (l, m, n)),
183
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
184
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
185
+ }
186
+
187
+ launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2)
188
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
189
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
190
+ launcher.verify((m, n, k), input_keys, result_keys, l)
191
+
192
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
193
+ def test_mixed_dag_partition_k(self):
194
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
195
+ F = alpha * accum + (beta * C + aux)
196
+ F_row_max = max(F, dim=[0, 1])
197
+ E = relu(F + 1) + cbias + rbias
198
+ E_col_max = max(E, dim=[0, 2])
199
+ D = E + F
200
+ return D, F, F_row_max, E_col_max
201
+
202
+ for m, n, k, l in self.get_problem_sizes(8):
203
+ example_inputs = {
204
+ "accum": self.fake_tensor(self.element, (l, m, n)),
205
+ "alpha": 1.0,
206
+ "C": self.fake_tensor(self.element, (l, m, n)),
207
+ "beta": 1.0,
208
+ "aux": self.fake_tensor(self.element, (l, m, n)),
209
+ "cbias": self.fake_tensor(self.element, (m, 1)),
210
+ "rbias": self.fake_tensor(self.element, (n,)),
211
+ "D": self.fake_tensor(self.element, (l, m, n)),
212
+ "F": self.fake_tensor(self.element, (l, m, n)),
213
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
214
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
215
+ }
216
+
217
+ tile_description = {
218
+ "threadblock_shape": [128, 128, 64],
219
+ "warp_count": [2, 2, 2]
220
+ }
221
+
222
+ launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2)
223
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
224
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
225
+ launcher.verify((m, n, k), input_keys, result_keys, l)
226
+
227
+ @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only")
228
+ def test_mixed_dag_stream_k(self):
229
+ def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias):
230
+ F = alpha * accum + (beta * C + aux)
231
+ F_row_max = max(F, dim=[0, 1])
232
+ E = relu(F + 1) + cbias + rbias
233
+ E_col_max = max(E, dim=[0, 2])
234
+ D = E + F
235
+ return D, F, F_row_max, E_col_max
236
+
237
+ # High per-sm occupancy tile_description
238
+ tile_description = {
239
+ "threadblock_shape": [128, 128, 32],
240
+ "warp_count": [2, 2, 1],
241
+ "stages": 3
242
+ }
243
+ tds = [None, tile_description]
244
+ for td in tds:
245
+ for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]):
246
+ if l == 1:
247
+ example_inputs = {
248
+ "accum": self.fake_tensor(self.element, (m, n)),
249
+ "alpha": 1.0,
250
+ "C": self.fake_tensor(self.element, (m, n)),
251
+ "beta": 1.0,
252
+ "aux": self.fake_tensor(self.element, (m, n)),
253
+ "cbias": self.fake_tensor(self.element, (m, 1)),
254
+ "rbias": self.fake_tensor(self.element, (n,)),
255
+ "D": self.fake_tensor(self.element, (m, n)),
256
+ "F": self.fake_tensor(self.element, (m, n)),
257
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
258
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
259
+ }
260
+ else:
261
+ example_inputs = {
262
+ "accum": self.fake_tensor(self.element, (l, m, n)),
263
+ "alpha": 1.0,
264
+ "C": self.fake_tensor(self.element, (l, m, n)),
265
+ "beta": 1.0,
266
+ "aux": self.fake_tensor(self.element, (l, m, n)),
267
+ "cbias": self.fake_tensor(self.element, (m, 1)),
268
+ "rbias": self.fake_tensor(self.element, (n,)),
269
+ "D": self.fake_tensor(self.element, (l, m, n)),
270
+ "F": self.fake_tensor(self.element, (l, m, n)),
271
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
272
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
273
+ }
274
+
275
+ if td is not None:
276
+ launcher = EVTTestBed(
277
+ self.element, evt_mixed_dag, example_inputs,
278
+ tile_description=td,
279
+ swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
280
+ else:
281
+ launcher = EVTTestBed(
282
+ self.element, evt_mixed_dag, example_inputs,
283
+ swizzling_functor=ThreadblockSwizzleStreamK, backend="torch")
284
+
285
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
286
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
287
+ launcher.verify((m, n, k), input_keys, result_keys, l)
288
+
289
+ def test_mixed_dag_no_batch(self):
290
+ def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias):
291
+ F = alpha * accum + (beta * C + aux)
292
+ F_row_max = max(F, dim=[0, 1])
293
+ E = relu(F + 1) + cbias + rbias
294
+ E_col_max = max(E, dim=[0, 2])
295
+ D = E + F
296
+ return D, F, F_row_max, E_col_max
297
+
298
+ for m, n, k, _ in self.get_problem_sizes(8):
299
+ example_inputs = {
300
+ "accum": self.fake_tensor(self.element, (m, n)),
301
+ "alpha": 1.0,
302
+ "C": self.fake_tensor(self.element, (m, n)),
303
+ "beta": 1.0,
304
+ "aux": self.fake_tensor(self.element, (m, n)),
305
+ "cbias": self.fake_tensor(self.element, (m, 1)),
306
+ "rbias": self.fake_tensor(self.element, (n,)),
307
+ "D": self.fake_tensor(self.element, (m, n)),
308
+ "F": self.fake_tensor(self.element, (m, n)),
309
+ "F_row_max": self.fake_tensor(DataType.f32, (n,)),
310
+ "E_col_max": self.fake_tensor(DataType.f32, (m, 1))
311
+ }
312
+
313
+ launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs)
314
+ input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"]
315
+ result_keys = ["D", "F", "F_row_max", "E_col_max"]
316
+ launcher.verify((m, n, k), input_keys, result_keys, 1)
317
+
318
+ if __name__ == '__main__':
319
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Unit test for store nodes in SM80/90
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen.backend import *
42
+ from cutlass_cppgen.epilogue import *
43
+
44
+ from utils.evt_testbed import EVTTestBed, EVTTestCaseBase
45
+
46
+ cutlass_cppgen.set_log_level(logging.WARNING)
47
+
48
+
49
+ @unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]")
50
+ class TestEVTStore(EVTTestCaseBase):
51
+
52
+ @unittest.skipIf(device_cc() != 90, "This test is only for CC 90")
53
+ def test_invalid_store(self):
54
+ """
55
+ Test invalid store
56
+ """
57
+ def evt_invalid_store(accum):
58
+ D = accum
59
+ F = D + 1 # D has users, which is not allowed on SM90 or higher
60
+ return D, F
61
+
62
+ for m, n, k, l in self.get_problem_sizes(8):
63
+ example_inputs = {
64
+ "accum": self.fake_tensor(self.element, (l, m, n)),
65
+ "D": self.fake_tensor(self.element, (l, m, n)),
66
+ "F": self.fake_tensor(self.element, (l, m, n))
67
+ }
68
+ with self.assertRaisesRegex(
69
+ RuntimeError,
70
+ r"On SM90 or higher, D is expected to be a output node with 0 users "
71
+ r"to enable smem reuse between C and D, but got 1"
72
+ ):
73
+ launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs)
74
+
75
+ break # Only need to test once
76
+
77
+ def test_aux_store(self):
78
+ """
79
+ Returning a tensor with shape [m, n]
80
+ """
81
+ def evt_aux_store(accum, alpha, C):
82
+ F = alpha * accum
83
+ D = F + C
84
+ return D, F
85
+
86
+ for m, n, k, l in self.get_problem_sizes(8):
87
+ example_inputs = {
88
+ "accum": self.fake_tensor(self.element, (l, m, n)),
89
+ "alpha": 0.5,
90
+ "C": self.fake_tensor(self.element, (l, m, n)),
91
+ "F": self.fake_tensor(self.element, (l, m, n)),
92
+ "D": self.fake_tensor(self.element, (l, m, n)),
93
+ }
94
+
95
+ launcher = EVTTestBed(self.element, evt_aux_store, example_inputs)
96
+ input_keys = ["C", "alpha"]
97
+ result_keys = ["D", "F"]
98
+ launcher.verify((m, n, k), input_keys, result_keys, l)
99
+
100
+ def test_col_reduce(self):
101
+ """
102
+ Reduction [m, n] -> [m, 1]
103
+ """
104
+ def evt_row_reduce(accum, alpha, C):
105
+ acc_row_max = max(accum, dim=[2,])
106
+ F = alpha * accum
107
+ F_row_max = max(F, dim=[0, 2])
108
+ D = F + C
109
+ return D, F_row_max, acc_row_max
110
+
111
+ for m, n, k, l in self.get_problem_sizes(8):
112
+ example_inputs = {
113
+ "accum": self.fake_tensor(self.element, (l, m, n)),
114
+ "alpha": 2.0,
115
+ "C": self.fake_tensor(self.element, (l, m, n)),
116
+ "F_row_max": self.fake_tensor(np.float32, (m, 1)),
117
+ "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)),
118
+ "D": self.fake_tensor(self.element, (l, m, n)),
119
+ }
120
+
121
+ launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs)
122
+ input_keys = ["C", "alpha"]
123
+ result_keys = ["D", "F_row_max", "acc_row_max"]
124
+ launcher.verify((m, n, k), input_keys, result_keys, l)
125
+
126
+ def test_row_reduce(self):
127
+ """
128
+ Reduction [m, n] -> [n]
129
+ """
130
+ def evt_col_reduce(accum, alpha, C):
131
+ acc_col_max = max(accum, dim=[1,])
132
+ F = alpha * accum
133
+ F_col_max = max(F, dim=[0, 1])
134
+ D = F + C
135
+ return D, F_col_max, acc_col_max
136
+
137
+ for m, n, k, l in self.get_problem_sizes(8):
138
+ example_inputs = {
139
+ "accum": self.fake_tensor(self.element, (l, m, n)),
140
+ "alpha": 2.0,
141
+ "C": self.fake_tensor(self.element, (l, m, n)),
142
+ "F_col_max": self.fake_tensor(np.float32, (n,)),
143
+ "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)),
144
+ "D": self.fake_tensor(self.element, (l, m, n)),
145
+ }
146
+
147
+ launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs)
148
+ input_keys = ["C", "alpha"]
149
+ result_keys = ["D", "F_col_max", "acc_col_max"]
150
+ launcher.verify((m, n, k), input_keys, result_keys, l)
151
+
152
+ def test_scalar_reduce(self):
153
+ """
154
+ Reduction [m, n] -> [1,]
155
+ """
156
+ def evt_scalar_reduce(accum, alpha, C):
157
+ acc_max = max(accum, dim=[1, 2])
158
+ F = alpha * accum
159
+ F_max = max(F, dim=[0, 1, 2])
160
+ D = F + C
161
+ return D, F_max, acc_max
162
+
163
+ for m, n, k, l in self.get_problem_sizes(8):
164
+ example_inputs = {
165
+ "accum": self.fake_tensor(self.element, (l, m, n)),
166
+ "alpha": 2.0,
167
+ "C": self.fake_tensor(self.element, (l, m, n)),
168
+ "acc_max": self.fake_tensor(np.float32, (l, 1, 1)),
169
+ "F_max": self.fake_tensor(np.float32, (1,)),
170
+ "D": self.fake_tensor(self.element, (l, m, n)),
171
+ }
172
+
173
+ launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs)
174
+ input_keys = ["C", "alpha"]
175
+ result_keys = ["D", "F_max", "acc_max"]
176
+ launcher.verify((m, n, k), input_keys, result_keys, l)
177
+
178
+
179
+ if __name__ == '__main__':
180
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ import pathlib
34
+ import unittest
35
+
36
+
37
+ if __name__ == '__main__':
38
+ loader = unittest.TestLoader()
39
+ script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
40
+ tests = loader.discover(script_dir, 'evt_*.py')
41
+ testRunner = unittest.runner.TextTestRunner()
42
+ results = testRunner.run(tests)
43
+ if not results.wasSuccessful():
44
+ raise Exception('Test cases failed')
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ ################################################################################
32
+
33
+ """
34
+ Testbed classes of EVT
35
+ """
36
+
37
+ import torch
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen import Tensor
42
+ import cutlass_cppgen.backend.evt
43
+ from cutlass_cppgen.shape import GemmCoord
44
+ from cutlass_cppgen.utils.datatypes import torch_type
45
+ from cutlass_cppgen.utils.profiler import CUDAEventProfiler
46
+
47
+
48
+ class EVTReferenceModule:
49
+ def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor):
50
+ self.layout_A = layout_A
51
+ self.layout_B = layout_B
52
+ self.layout_C = layout_C
53
+ self.epilogue_visitor = epilogue_visitor
54
+
55
+ def run(self, A, B, C, problem_size, alpha, beta, batch=1):
56
+ if self.layout_A == cutlass_cppgen.LayoutType.RowMajor:
57
+ A_row = A.view((batch, problem_size.m, problem_size.k))
58
+ else:
59
+ A_col = A.view((batch, problem_size.k, problem_size.m))
60
+ A_row = torch.permute(A_col, (0, 2, 1))
61
+
62
+ if self.layout_B == cutlass_cppgen.LayoutType.RowMajor:
63
+ B_row = B.view((batch, problem_size.k, problem_size.n))
64
+ else:
65
+ B_col = B.view((batch, problem_size.n, problem_size.k))
66
+ B_row = torch.permute(B_col, (0, 2, 1))
67
+
68
+ if self.layout_C == cutlass_cppgen.LayoutType.RowMajor:
69
+ C_row = C.view((batch, problem_size.m, problem_size.n))
70
+ else:
71
+ C_col = C.view((batch, problem_size.n, problem_size.m))
72
+ C_row = torch.permute(C_col, (0, 2, 1))
73
+
74
+ out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta
75
+
76
+ if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor:
77
+ out = torch.permute(out_row, (0, 2, 1))
78
+ else:
79
+ out = out_row
80
+
81
+ return torch.flatten(out)
82
+
83
+ def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None):
84
+ # Running the mainloop
85
+ accum = self.run(
86
+ A, B, C, problem_size, 1.0, 0.0, batch=batch
87
+ ).reshape(batch, problem_size.m, problem_size.n)
88
+
89
+ # Running the epilogue
90
+ epilogue_args["accum"] = accum
91
+ references = self.epilogue_visitor(**epilogue_args)
92
+
93
+ # Return the results
94
+ if not isinstance(references, tuple):
95
+ references = (references,)
96
+ return references
97
+
98
+
99
+ class EVTTestBed:
100
+ """
101
+ Epilogue Visitor Testbed
102
+ """
103
+ def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None:
104
+ self.element = element
105
+ layout = cutlass_cppgen.LayoutType.RowMajor
106
+ self.example_inputs = example_inputs
107
+
108
+ # Create the Gemm plan
109
+ self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32)
110
+
111
+ if "tile_description" in kwargs:
112
+ self.plan.tile_description = kwargs["tile_description"]
113
+
114
+ if "swizzling_functor" in kwargs:
115
+ self.plan.swizzling_functor = kwargs["swizzling_functor"]
116
+
117
+ # Compile the epilogue visitor
118
+ epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs)
119
+ if "epilogue_stages" in kwargs:
120
+ epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"]
121
+ self.plan.epilogue_visitor = epilogue_visitor
122
+
123
+ # Reference model
124
+ self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor)
125
+
126
+ self.profile = profile
127
+
128
+ def get_torch_tensor(self, shape, dtype=None, fill=None):
129
+ if dtype is None:
130
+ dtype = self.element
131
+
132
+ dtype = torch_type(dtype)
133
+ if fill is None:
134
+ return torch.ceil(
135
+ torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5)
136
+ )
137
+ else:
138
+ return torch.full(shape, fill, dtype=dtype, device="cuda")
139
+
140
+ def verify(self, problem_size, input_keys, result_keys, batch_count=1):
141
+ """
142
+ Verify the results
143
+ """
144
+ problem_size = GemmCoord(*problem_size)
145
+
146
+ # Initiate the GEMM arguments
147
+ tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k))
148
+ tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n))
149
+
150
+ # Initialize the epilogue args
151
+ epilogue_args = {}
152
+ for key in self.example_inputs.keys():
153
+ if key in input_keys:
154
+ tensor = self.example_inputs[key]
155
+ if isinstance(tensor, Tensor):
156
+ epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element)
157
+ else:
158
+ epilogue_args[key] = tensor
159
+ elif key in result_keys:
160
+ tensor = self.example_inputs[key]
161
+ if isinstance(tensor, Tensor):
162
+ if "max" in key:
163
+ fill = -1000
164
+ else:
165
+ fill = 0
166
+ epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element, fill=fill)
167
+ else:
168
+ epilogue_args[key] = tensor
169
+
170
+ tensor_D = epilogue_args["D"]
171
+ if "C" in epilogue_args:
172
+ tensor_C = epilogue_args["C"]
173
+ else:
174
+ tensor_C = tensor_D
175
+ # Run the device kernel
176
+ self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args)
177
+
178
+ # Run the host reference
179
+ evt_args_inputs = {}
180
+ for key in input_keys:
181
+ evt_args_inputs[key] = epilogue_args[key]
182
+
183
+ reference_results = self.reference_fn(
184
+ tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs)
185
+
186
+ # Compare the results
187
+ for result, ref in zip(result_keys, reference_results):
188
+ assert torch.equal(
189
+ epilogue_args[result].flatten(),
190
+ ref.masked_fill(torch.isnan(ref), float('inf')).flatten())
191
+
192
+ # Run profile
193
+ if self.profile:
194
+ profiler = CUDAEventProfiler(
195
+ self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D,
196
+ visitor_args = epilogue_args
197
+ )
198
+ print(f"Cutlass Python Duration: {profiler()}")
199
+
200
+
201
+ class EVTTestCaseBase(unittest.TestCase):
202
+ """
203
+ Base class for EVT Unittest
204
+ """
205
+ def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None:
206
+ super().__init__(methodName)
207
+
208
+ self.element = cutlass_cppgen.DataType.f16
209
+ self.l, self.m, self.n, self.k = lmnk
210
+
211
+ self.problem_size = (self.m, self.n, self.k)
212
+
213
+ torch.random.manual_seed(42)
214
+
215
+ def fake_tensor(self, element, shape, stride=None):
216
+ if stride is None:
217
+ return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor)
218
+ else:
219
+ return Tensor(element=element, shape=shape, stride=stride)
220
+
221
+ def get_problem_sizes(self, alignment, k=None, batch_count=[3,]):
222
+ k = k if k else self.k
223
+ problem_size_m = [alignment, 512 - 3 * alignment]
224
+ problem_size_n = [alignment, 512 - alignment]
225
+ if alignment % 8 == 0:
226
+ problem_size_m.append(768)
227
+ problem_size_n.append(768)
228
+ problem_size_l = batch_count
229
+ problem_sizes = []
230
+ for m in problem_size_m:
231
+ for n in problem_size_n:
232
+ for l in problem_size_l:
233
+ problem_sizes.append((m, n, k, l))
234
+
235
+ return problem_sizes
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ High-level tests for running batched GEMMs
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ from math import prod
40
+ import unittest
41
+
42
+ import cutlass_cppgen
43
+ from cutlass_cppgen.backend.utils.device import device_cc
44
+ import torch
45
+
46
+ from utils import LayoutCombination
47
+
48
+ cutlass_cppgen.set_log_level(logging.WARNING)
49
+
50
+ torch.manual_seed(2023)
51
+
52
+
53
+ def pytorch_reference(A, B, C, alpha, beta):
54
+ # Get the batch count. Assume that any of A, B, and C
55
+ # with a batch dimension have matching batch count. Thus,
56
+ # we break out of the loop once we have found the first
57
+ # tensor containing a batch dimension.
58
+ batch_count = (1,)
59
+ for tensor in [A, B, C]:
60
+ if len(tensor.shape) > 2:
61
+ batch_count = tensor.shape[:-2]
62
+ break
63
+
64
+ int_batch_count = prod(batch_count)
65
+
66
+ def add_batch(tensor):
67
+ if len(tensor.shape) == 2:
68
+ return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1)
69
+ else:
70
+ return tensor.reshape(-1, tensor.size(-2), tensor.size(-1))
71
+
72
+ # Reshape tensors to have batch dimension
73
+ A = add_batch(A)
74
+ B = add_batch(B)
75
+ C = add_batch(C)
76
+
77
+ ret = (torch.bmm(A, B) * alpha) + (C * beta)
78
+ reshape_vals = batch_count + C.shape[-2:]
79
+ return ret.reshape(*reshape_vals)
80
+
81
+
82
+ def initialize(rows, cols, batch):
83
+ tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half()
84
+ if len(batch) > 0 and prod(batch) > 1:
85
+ reshape_vals = batch + (rows, cols)
86
+ return tensor.reshape(*reshape_vals)
87
+ else:
88
+ return tensor.reshape(rows, cols)
89
+
90
+
91
+ class GemmF16Batched(unittest.TestCase):
92
+ def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool):
93
+ M = 512
94
+ N = 256
95
+ K = 128
96
+ alpha = 1.
97
+ beta = 2.
98
+
99
+ A = initialize(M, K, batch_count if batch_A else (1,))
100
+ B = initialize(K, N, batch_count if batch_B else (1,))
101
+ C = initialize(M, N, batch_count if batch_C else (1,))
102
+ D = initialize(M, N, batch_count)
103
+
104
+ plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32)
105
+ plan.run(A, B, C, D, alpha, beta)
106
+ reference = pytorch_reference(A, B, C, alpha, beta)
107
+ assert reference.equal(D)
108
+
109
+ def test_batched_ABC(self):
110
+ self.run_batched((3,), True, True, True)
111
+ self.run_batched((2, 3), True, True, True)
112
+
113
+ def test_batched_AB(self):
114
+ self.run_batched((3,), True, True, False)
115
+ self.run_batched((2, 3), True, True, False)
116
+
117
+ def test_batched_AC(self):
118
+ self.run_batched((3,), True, False, True)
119
+ self.run_batched((2, 3), True, False, True)
120
+
121
+ def test_batched_BC(self):
122
+ self.run_batched((3,), False, True, True)
123
+ self.run_batched((2, 3), False, True, True)
124
+
125
+ def test_batched_A(self):
126
+ self.run_batched((3,), True, False, False)
127
+ self.run_batched((2, 3), True, False, False)
128
+
129
+ def test_batched_B(self):
130
+ self.run_batched((3,), False, True, False)
131
+ self.run_batched((2, 3), False, True, False)
132
+
133
+ if __name__ == '__main__':
134
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with F16 operands on SM80
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 80
49
+ dtype = cutlass_cppgen.DataType.f16
50
+
51
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
52
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
53
+ class GemmF16Sm80(unittest.TestCase):
54
+ """
55
+ Wrapper class to which tests will be added dynamically in __main__
56
+ """
57
+ pass
58
+
59
+
60
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
61
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
62
+ class GemmF16Sm80StreamK(unittest.TestCase):
63
+ """
64
+ Wrapper class to which tests will be added dynamically in __main__
65
+ """
66
+ pass
67
+
68
+ add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
69
+
70
+ # Tests using TensorOp
71
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
72
+
73
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
74
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
75
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
76
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
77
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
78
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
79
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
80
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
81
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
82
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
83
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
84
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
85
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
86
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
87
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
88
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
89
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
90
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
91
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
92
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3)
93
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
94
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3)
95
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
96
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
97
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
98
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
99
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
100
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
101
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
102
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
103
+ add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
104
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
105
+
106
+ # Tests using SIMT
107
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
108
+
109
+ add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
110
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
111
+ add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
112
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
113
+ add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
114
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
115
+ add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
116
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
117
+ add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
118
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
119
+
120
+ # Stream K tests
121
+ add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
122
+ add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
123
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
124
+ add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
125
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5)
126
+
127
+ if __name__ == '__main__':
128
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with F16 operands on SM90
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 90
49
+ dtype = cutlass_cppgen.DataType.f16
50
+
51
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
52
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
53
+ class GemmF16Sm90(unittest.TestCase):
54
+ """
55
+ Wrapper class to which tests will be added dynamically in __main__
56
+ """
57
+ pass
58
+
59
+
60
+ add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype,
61
+ warp_count=None, compilation_modes=['nvcc'])
62
+
63
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
64
+
65
+ # Tests with 1x1x1 clusters
66
+ add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1])
67
+ add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
68
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3)
69
+ add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
70
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
71
+ add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
72
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
73
+ add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
74
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
75
+ add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
76
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
77
+ add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16,
78
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None)
79
+ add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16,
80
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None)
81
+ add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
82
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None)
83
+ add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
84
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5)
85
+ add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16,
86
+ element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None)
87
+
88
+ # Tests with different cluster shapes
89
+ add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None)
90
+ add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
91
+ element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1])
92
+ add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
93
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1])
94
+ add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
95
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1])
96
+ add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
97
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1])
98
+ add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
99
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1])
100
+ add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
101
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1])
102
+ add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
103
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1])
104
+ add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32,
105
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1])
106
+
107
+ # Tests for different schedule modes
108
+ add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4],
109
+ element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32,
110
+ opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None)
111
+ add_test_schedule(
112
+ cluster_shape=[1, 1, 1],
113
+ kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
114
+ epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
115
+ )
116
+ add_test_schedule(
117
+ cluster_shape=[1, 1, 1],
118
+ kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
119
+ epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative
120
+ )
121
+ add_test_schedule(
122
+ cluster_shape=[2, 1, 1],
123
+ kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
124
+ epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
125
+ )
126
+ add_test_schedule(
127
+ cluster_shape=[2, 1, 1],
128
+ kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative,
129
+ epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative
130
+ )
131
+
132
+ # Tests using SIMT
133
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2)
134
+ add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8])
135
+ add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8])
136
+ add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8])
137
+ add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8])
138
+ add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8])
139
+
140
+ # Tests with void-C kernels
141
+ add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16,
142
+ element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None,
143
+ cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void)
144
+
145
+ if __name__ == '__main__':
146
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with F32 operands on SM80
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 80
49
+ dtype = cutlass_cppgen.DataType.f32
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmF32Sm80(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
62
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
63
+ class GemmF32Sm80StreamK(unittest.TestCase):
64
+ """
65
+ Wrapper class to which tests will be added dynamically in __main__
66
+ """
67
+ pass
68
+
69
+
70
+ add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
71
+
72
+ # Tests using TensorOp
73
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
74
+
75
+ add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
76
+ element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
77
+ add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
78
+ element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
79
+ add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
80
+ element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3)
81
+ add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
82
+ element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4)
83
+ # Tests using SIMT
84
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
85
+
86
+ add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
87
+ element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
88
+ add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
89
+ element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
90
+ add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
91
+ element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
92
+ add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
93
+ element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
94
+ add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
95
+ element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
96
+
97
+ # Stream K tests
98
+ add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
99
+ add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype,
100
+ element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3)
101
+
102
+
103
+ if __name__ == '__main__':
104
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with F64 operands on SM80
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 80
49
+ dtype = cutlass_cppgen.DataType.f64
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmF64Sm80(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
62
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
63
+ class GemmF64Sm80StreamK(unittest.TestCase):
64
+ """
65
+ Wrapper class to which tests will be added dynamically in __main__
66
+ """
67
+ pass
68
+
69
+
70
+ add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
71
+
72
+ # Tests using TensorOp
73
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
74
+
75
+ add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
76
+ element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
77
+ add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
78
+ element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4)
79
+ add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
80
+ element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5)
81
+
82
+ # Tests using SIMT
83
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
84
+
85
+ add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
86
+ element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
87
+ add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
88
+ element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
89
+ add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
90
+ element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
91
+ add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
92
+ element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
93
+ add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
94
+ element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
95
+
96
+ # Stream K tests
97
+ add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
98
+ add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype,
99
+ element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3)
100
+
101
+
102
+ if __name__ == '__main__':
103
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with F64 operands on SM90
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 90
49
+ dtype = cutlass_cppgen.DataType.f64
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmF64Sm90(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1],
62
+ element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc'])
63
+
64
+ add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3)
65
+ add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3)
66
+ add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2)
67
+ add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2)
68
+
69
+
70
+ if __name__ == '__main__':
71
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with S8 operands on SM90
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 90
49
+ dtype = cutlass_cppgen.DataType.e4m3
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmF8E4M3Sm90(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc'])
62
+
63
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
64
+
65
+ # Test with 1x1x1 clusters
66
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
67
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
68
+
69
+ # Tests with different cluster shapes
70
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
71
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
72
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
73
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
74
+
75
+ # Tests with warp-specialized ping-pong schedule
76
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
77
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
78
+ kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
79
+ epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized)
80
+
81
+ # Tests for SIMT
82
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
83
+ add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3,
84
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
85
+
86
+
87
+ #
88
+ # Add a test for E5M2
89
+ #
90
+ dtype = cutlass_cppgen.DataType.e5m2
91
+
92
+
93
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
94
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
95
+ class GemmF8E5M2Sm90(unittest.TestCase):
96
+ """
97
+ Wrapper class to which tests will be added dynamically in __main__
98
+ """
99
+ pass
100
+
101
+
102
+ add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc'])
103
+
104
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
105
+
106
+ # Tests with 1x1x1 clusters
107
+ add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype,
108
+ element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3)
109
+
110
+
111
+ if __name__ == '__main__':
112
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with mixed operands on SM80
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 80
49
+ dtype =cutlass_cppgen.DataType.f16
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmMixedSm80(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1],
62
+ opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64],
63
+ warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32)
64
+
65
+ # Test with upcast on A
66
+ add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT)
67
+ add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN)
68
+
69
+ # Test with upcast on B
70
+ add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT)
71
+ add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN)
72
+
73
+
74
+ if __name__ == '__main__':
75
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with S8 operands on SM80
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 80
49
+ dtype = cutlass_cppgen.DataType.s8
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmS8Sm80(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.')
62
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
63
+ class GemmS8Sm80StreamK(unittest.TestCase):
64
+ """
65
+ Wrapper class to which tests will be added dynamically in __main__
66
+ """
67
+ pass
68
+
69
+
70
+ add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1])
71
+
72
+ # Tests using TensorOp
73
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
74
+
75
+ add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
76
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3)
77
+ add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
78
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
79
+ add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32,
80
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4)
81
+
82
+ # Tests using SIMT
83
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
84
+
85
+ add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
86
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
87
+ add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
88
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2)
89
+ add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
90
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2)
91
+ add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32,
92
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2)
93
+ add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32,
94
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2)
95
+
96
+ # Stream K tests
97
+ add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK)
98
+ add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8,
99
+ element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3)
100
+
101
+
102
+ if __name__ == '__main__':
103
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Low-level functionality tests for GEMM with S8 operands on SM90
35
+ """
36
+
37
+ from functools import partial
38
+ import logging
39
+ import unittest
40
+
41
+ import cutlass_cppgen
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+
44
+ from utils import LayoutCombination, add_test_gemm
45
+
46
+
47
+ cutlass_cppgen.set_log_level(logging.WARNING)
48
+ cc = 90
49
+ dtype = cutlass_cppgen.DataType.s8
50
+
51
+
52
+ @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
53
+ @unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
54
+ class GemmS8Sm90(unittest.TestCase):
55
+ """
56
+ Wrapper class to which tests will be added dynamically in __main__
57
+ """
58
+ pass
59
+
60
+
61
+ add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc'])
62
+
63
+ add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
64
+
65
+ # Tests with 1x1x1 clusters
66
+ add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
67
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3)
68
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
69
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
70
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8,
71
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
72
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
73
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None)
74
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
75
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None)
76
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8,
77
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
78
+
79
+ # Tests with different cluster shapes
80
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
81
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
82
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
83
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
84
+
85
+ # Tests with warp-specialized ping-pong schedule
86
+ add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8,
87
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
88
+ kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
89
+ epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized)
90
+
91
+ # Tests for SIMT
92
+ add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
93
+ add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8,
94
+ element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
95
+
96
+
97
+ if __name__ == '__main__':
98
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from math import prod
34
+ import os
35
+ import re
36
+ import subprocess
37
+
38
+ import torch
39
+
40
+ from cutlass_library import (
41
+ DataType,
42
+ DataTypeSize,
43
+ GemmUniversalMode,
44
+ LayoutType,
45
+ OpcodeClass,
46
+ ShortDataTypeNames,
47
+ SwizzlingFunctor
48
+ )
49
+
50
+ from cutlass_cppgen.backend import compiler
51
+ from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal
52
+ from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation
53
+ from cutlass_cppgen.shape import GemmCoord, MatrixCoord
54
+ from cutlass_cppgen.utils.datatypes import torch_type
55
+
56
+
57
+ class GemmUniversalLauncher:
58
+ def __init__(
59
+ self,
60
+ operation,
61
+ seed=2080,
62
+ verification=True,
63
+ iterations=500,
64
+ compiler_mode= "nvcc",
65
+ **kwargs,
66
+ ) -> None:
67
+ self.math_operation = operation.tile_description.math_instruction.math_operation
68
+ self.verification = verification
69
+
70
+ if compiler_mode == "nvcc":
71
+ compiler.nvcc()
72
+ elif compiler_mode == "nvrtc":
73
+ compiler.nvrtc()
74
+ else:
75
+ raise Exception(f"Unexpected compiler string {compiler_mode}")
76
+
77
+ op_list = [operation]
78
+ if operation.arch < 90:
79
+ # Split K via Python is currently only supported for pre-SM90 kernels
80
+ self.reduction_operation: ReductionOperation = ReductionOperation(
81
+ shape=MatrixCoord(4, 32 * operation.C.alignment),
82
+ C=operation.C,
83
+ element_accumulator=operation.tile_description.math_instruction.element_accumulator,
84
+ element_compute=operation.epilogue_functor.element_epilogue,
85
+ epilogue_functor=operation.epilogue_functor,
86
+ count=operation.C.alignment,
87
+ )
88
+ op_list.append(self.reduction_operation)
89
+
90
+ compiler.add_module(op_list, bypass_cache=False)
91
+
92
+ self.operation = operation
93
+
94
+ self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element)
95
+ self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element)
96
+ self.dtype_C = torch_type(operation.C.element)
97
+ self.dtype_D = torch_type(operation.epilogue_functor.element_output)
98
+
99
+ element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element])
100
+
101
+ if element_size == 1:
102
+ self.rand_max = 1
103
+ self.rand_min = 0
104
+ elif element_size <= 8:
105
+ self.rand_max = 1
106
+ self.rand_min = -1
107
+ elif element_size == 16:
108
+ self.rand_max = 4
109
+ self.rand_min = -4
110
+ else:
111
+ self.rand_max = 8
112
+ self.rand_min = -8
113
+
114
+ self.seed = seed
115
+
116
+ self.compute_type = operation.epilogue_functor.element_epilogue
117
+ self.accumulator_type = operation.tile_description.math_instruction.element_accumulator
118
+
119
+ def print_problem_size(self, p, mode, batch_count):
120
+ if mode == GemmUniversalMode.Gemm:
121
+ mode = "Gemm"
122
+ elif mode == GemmUniversalMode.Batched:
123
+ mode = "GemmBatched"
124
+ elif mode == GemmUniversalMode.GemmSplitKParallel:
125
+ mode = "GemmSplitKParallel"
126
+ print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}")
127
+
128
+ def uniform_init(self, shape, dtype, layout):
129
+ size = prod(shape)
130
+ if dtype.is_floating_point:
131
+ # Initialize data in FP32 and call convert to the data type we desire.
132
+ # This is a workaround for the following error that occurs when attempting to
133
+ # call uniform_ on a tensor with torch.float8_e4m3fn data:
134
+ # RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn'
135
+ data = torch.ceil(
136
+ torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_(
137
+ self.rand_min - 0.5, self.rand_max - 0.5)
138
+ ).to(dtype)
139
+ else:
140
+ # PyTorch does not currently support integer-typed matrix multiplications on GPU.
141
+ # Fall back to CPU for integer type references.
142
+ data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1)
143
+
144
+ is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == dtype == getattr(torch, "float8_e5m2", -1)
145
+
146
+ if dtype == torch.float64 or dtype == torch.float32 or is_fp8:
147
+ data = data.to("cpu")
148
+
149
+ data_ref = data.reshape(shape)
150
+
151
+ if layout == LayoutType.RowMajor:
152
+ data_cutlass = data_ref
153
+ else:
154
+ data_cutlass = data_ref.transpose(-1, -2).contiguous()
155
+
156
+ data_cutlass = data_cutlass.to("cuda")
157
+
158
+ # As of this writing, few operations in PyTorch are supported with FP8 data.
159
+ # Thus, we perform computation in FP32 for FP8 reference checks.
160
+ if is_fp8:
161
+ data_ref = data_ref.to(torch.float32)
162
+
163
+ return data_cutlass, data_ref
164
+
165
+ def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
166
+ # If any tensor is on CPU, place all tensors on CPU unless only
167
+ # tensor C is on CPU
168
+ # Handle mixed-input cases by casting to the larger data type and overriding
169
+ # to whatever the data type of the larger type is
170
+ if self.dtype_A != self.dtype_B:
171
+ if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
172
+ tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
173
+ else:
174
+ tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)
175
+
176
+ devices = [x.device.type for x in [tensor_A, tensor_B]]
177
+ if tensor_C is not None:
178
+ devices.append(tensor_C.device.type)
179
+
180
+ if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
181
+ device = torch.device("cpu")
182
+ else:
183
+ device = tensor_A.device
184
+
185
+ tensor_A = tensor_A.to(device)
186
+ tensor_B = tensor_B.to(device)
187
+ if tensor_C is not None:
188
+ tensor_C = tensor_C.to(device)
189
+
190
+ dtype = torch_type(self.compute_type)
191
+ alpha_torch = torch.tensor([alpha], device=device).to(dtype)
192
+ beta_torch = torch.tensor([beta], device=device).to(dtype)
193
+
194
+ tmp = tensor_A @ tensor_B
195
+ tensor_D_ref = (alpha_torch * tmp)
196
+ if tensor_C is not None:
197
+ tensor_D_ref += (tensor_C * beta_torch)
198
+ return tensor_D_ref.to(self.dtype_D)
199
+
200
+ def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
201
+ torch.random.manual_seed(self.seed)
202
+
203
+ # Assign an actual batch count in cases where we are not running in batched mode.
204
+ # This is to differentiate between the number of split K slices and the batch count,
205
+ # which are overloaded within the single `batch_count` variable.
206
+ if mode == GemmUniversalMode.Batched:
207
+ true_batch_count = batch_count
208
+ else:
209
+ true_batch_count = 1
210
+
211
+ def transpose(layout):
212
+ if layout == LayoutType.RowMajor:
213
+ return LayoutType.ColumnMajor
214
+ else:
215
+ return LayoutType.RowMajor
216
+
217
+ tensor_A, tensor_A_ref = self.uniform_init(
218
+ (true_batch_count, problem_size.m, problem_size.k),
219
+ self.dtype_A,
220
+ self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout),
221
+ )
222
+ tensor_B, tensor_B_ref = self.uniform_init(
223
+ (true_batch_count, problem_size.k, problem_size.n),
224
+ self.dtype_B,
225
+ self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout),
226
+ )
227
+ if self.dtype_C is not None:
228
+ tensor_C, tensor_C_ref = self.uniform_init(
229
+ (true_batch_count, problem_size.m, problem_size.n),
230
+ self.dtype_C,
231
+ self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
232
+ )
233
+ else:
234
+ tensor_C = None
235
+ tensor_C_ref = None
236
+
237
+ tensor_D, _ = self.uniform_init(
238
+ (true_batch_count, problem_size.m, problem_size.n),
239
+ self.dtype_D,
240
+ self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout),
241
+ )
242
+ tensor_D = torch.zeros_like(tensor_D)
243
+
244
+ if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]:
245
+ alpha = int(alpha)
246
+ beta = int(beta)
247
+
248
+ #
249
+ # Launch kernel
250
+ #
251
+
252
+ arguments = GemmArguments(
253
+ operation=self.operation,
254
+ problem_size=problem_size,
255
+ A=tensor_A,
256
+ B=tensor_B,
257
+ C=tensor_C,
258
+ D=tensor_D,
259
+ output_op=self.operation.epilogue_type(alpha, beta),
260
+ gemm_mode=mode,
261
+ split_k_slices=split_k_slices,
262
+ batch=batch_count,
263
+ )
264
+
265
+ if mode == GemmUniversalMode.GemmSplitKParallel:
266
+ reduction_arguments = ReductionArguments(
267
+ self.reduction_operation,
268
+ problem_size=[problem_size.m, problem_size.n],
269
+ partitions=split_k_slices,
270
+ workspace=arguments.ptr_D,
271
+ destination=tensor_D,
272
+ source=tensor_C,
273
+ output_op=self.reduction_operation.epilogue_type(alpha, beta),
274
+ )
275
+
276
+ self.operation.run(arguments)
277
+
278
+ if mode == GemmUniversalMode.GemmSplitKParallel:
279
+ self.reduction_operation.run(reduction_arguments)
280
+
281
+ passed = True
282
+
283
+ if self.verification:
284
+ if mode == GemmUniversalMode.GemmSplitKParallel:
285
+ reduction_arguments.sync()
286
+
287
+ # Free memory allocated by args because we are not
288
+ # calling `arguments.sync()` in this case (which will free memory)
289
+ arguments.free()
290
+ else:
291
+ arguments.sync()
292
+ tensor_D_ref = self.reference(
293
+ problem_size,
294
+ tensor_A_ref,
295
+ tensor_B_ref,
296
+ tensor_C_ref,
297
+ alpha,
298
+ beta,
299
+ )
300
+
301
+ tensor_D_ref = tensor_D_ref.to('cuda')
302
+
303
+ if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor:
304
+ tensor_D = tensor_D.transpose(-1, -2).contiguous()
305
+
306
+ passed = tensor_D.equal(tensor_D_ref)
307
+
308
+ try:
309
+ assert passed
310
+ except AssertionError:
311
+ self.print_problem_size(problem_size, mode, batch_count)
312
+ del arguments
313
+ if mode == GemmUniversalMode.GemmSplitKParallel:
314
+ del reduction_arguments
315
+
316
+ return passed
317
+
318
+
319
+ def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"):
320
+ passed = True
321
+
322
+ minimum_operand_element_size = min(
323
+ DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]
324
+ )
325
+ opcode_class = operation.tile_description.math_instruction.opcode_class
326
+
327
+ if opcode_class == OpcodeClass.Simt:
328
+ alignment = 1
329
+ else:
330
+ alignment = 128 // minimum_operand_element_size
331
+
332
+ alignment_m = alignment
333
+ alignment_n = alignment
334
+ alignment_k = alignment
335
+
336
+ # INT8 alignment constraints
337
+ if opcode_class == OpcodeClass.Simt:
338
+ A_is_s8 = operation.A.element == DataType.s8
339
+ B_is_s8 = operation.B.element == DataType.s8
340
+
341
+ if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor:
342
+ alignment_m = 4
343
+ if B_is_s8 == DataType.s8 and operation.A.layout == LayoutType.RowMajor:
344
+ alignment_n = 4
345
+ if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor):
346
+ alignment_k = 4
347
+
348
+ threadblock_k = operation.tile_description.threadblock_shape[2]
349
+
350
+ assert testcase != "interleaved"
351
+
352
+ supports_split_k = operation.arch < 90 and not operation.swizzling_functor == SwizzlingFunctor.StreamK
353
+
354
+ if testcase == "multistage":
355
+ modes = [GemmUniversalMode.Gemm]
356
+ problem_size_m = [16, 528]
357
+ problem_size_n = [16, 528]
358
+ problem_size_k = [
359
+ threadblock_k,
360
+ threadblock_k * operation.tile_description.stages
361
+ + operation.tile_description.math_instruction.instruction_shape[2],
362
+ ]
363
+ problem_alpha = [1.0]
364
+ problem_beta = [0.0]
365
+ batch_counts = [1]
366
+ else:
367
+ modes = [GemmUniversalMode.Gemm]
368
+ batch_counts = [1, 2, 3, 5, 7]
369
+ if supports_split_k:
370
+ modes.append(GemmUniversalMode.GemmSplitKParallel)
371
+
372
+ problem_size_m = [alignment_m, 512 - 3 * alignment_m]
373
+ problem_size_n = [alignment_n, 512 - 2 * alignment_n]
374
+ if operation.tile_description.stages is None:
375
+ stages_for_k_calc = 7
376
+ else:
377
+ stages_for_k_calc = operation.tile_description.stages
378
+ problem_size_k = [
379
+ alignment_k,
380
+ threadblock_k * stages_for_k_calc - alignment_k,
381
+ threadblock_k * stages_for_k_calc * 3 - alignment_k,
382
+ ]
383
+ problem_alpha = [1.0]
384
+ problem_beta = [2.0]
385
+
386
+ testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode)
387
+
388
+ for mode in modes:
389
+ for m in problem_size_m:
390
+ for n in problem_size_n:
391
+ for k in problem_size_k:
392
+ for batch_count in batch_counts:
393
+ for alpha in problem_alpha:
394
+ for beta in problem_beta:
395
+ # skip very small K problems
396
+ if testcase == "universal":
397
+ if k // batch_count < 2 * threadblock_k:
398
+ continue
399
+
400
+ problem_size = GemmCoord(m, n, k)
401
+
402
+ if supports_split_k:
403
+ split_k_slices = batch_count
404
+ else:
405
+ split_k_slices = 1
406
+
407
+ overridden_mode = mode
408
+ if mode == GemmUniversalMode.Gemm and batch_count > 1:
409
+ overridden_mode = GemmUniversalMode.Batched
410
+
411
+ passed = testbed.run(
412
+ overridden_mode,
413
+ problem_size,
414
+ batch_count,
415
+ split_k_slices,
416
+ alpha,
417
+ beta,
418
+ )
419
+
420
+ if not passed:
421
+ return False
422
+
423
+ return passed
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
import pathlib
import unittest


if __name__ == '__main__':
    # Discover every gemm_*.py test module living beside this script and run
    # the aggregated suite, turning an unsuccessful result into an exception
    # so CI pipelines see a non-zero exit.
    this_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
    suite = unittest.TestLoader().discover(this_dir, 'gemm_*.py')
    outcome = unittest.runner.TextTestRunner().run(suite)
    if not outcome.wasSuccessful():
        raise Exception('Test cases failed')
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ from cutlass_library import SubstituteTemplate
34
+
35
+ import cutlass_cppgen
36
+ from cutlass_library import (
37
+ DataTypeNames,
38
+ EpilogueScheduleSuffixes,
39
+ KernelScheduleSuffixes,
40
+ LayoutType,
41
+ OpcodeClassNames,
42
+ ShortDataTypeNames,
43
+ ShortLayoutTypeNames
44
+ )
45
+ from cutlass_cppgen.backend import library
46
+
47
+ from gemm_testbed import test_all_gemm
48
+
49
+
50
class Layout:
    """
    Shorthand aliases translating BLAS-style transpose notation into CUTLASS
    layout types: ``T`` (transposed) maps to row-major and ``N``
    (non-transposed) maps to column-major.
    """

    T = LayoutType.RowMajor
    N = LayoutType.ColumnMajor
+
58
+
59
class LayoutCombination:
    """
    Pre-built ``(layout_A, layout_B, layout_C)`` triples covering every
    combination of row-major (T) and column-major (N) operand layouts
    for a GEMM.
    """

    NNN = (Layout.N, Layout.N, Layout.N)
    NNT = (Layout.N, Layout.N, Layout.T)
    NTN = (Layout.N, Layout.T, Layout.N)
    NTT = (Layout.N, Layout.T, Layout.T)
    TNN = (Layout.T, Layout.N, Layout.N)
    TNT = (Layout.T, Layout.N, Layout.T)
    TTN = (Layout.T, Layout.T, Layout.N)
    TTT = (Layout.T, Layout.T, Layout.T)
+
73
+
74
def get_name(
    layouts,
    alignments,
    element_output,
    element_accumulator,
    element_epilogue,
    cluster_shape,
    threadblock_shape,
    stages,
    element_a,
    element_b,
    element_c,
    arch,
    opclass,
    kernel_schedule=None,
    epilogue_schedule=None,
    suffix="",
):
    """
    Generates a procedural name for a test case.

    :param layouts: indexable container of layouts of A, B, and C operands
    :param alignments: indexable container of alignments of A, B, and C operands
    :param element_output: data type of the output element
    :param element_accumulator: data type used in accumulation
    :param element_epilogue: data type used in computing the epilogue
    :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched
    :param threadblock_shape: indexable container of dimensions of threadblock tiles
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param element_a: data type of operand A
    :param element_b: data type of operand B
    :param element_c: data type of operand C
    :param arch: compute capability of kernel being generated
    :type arch: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass_cppgen.OpcodeClass
    :param kernel_schedule: kernel_schedule type
    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
    :param epilogue_schedule: epilogue_schedule type
    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
    :param suffix: additional string to add to the suffix of the name
    :type suffix: str

    :return: str
    """
    # Template placeholders are filled via cutlass_library.SubstituteTemplate.
    name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}"

    # Precompute the per-operand pieces so the mapping below stays flat.
    short_layouts = [ShortLayoutTypeNames[layout] for layout in layouts]
    tb = [str(dim) for dim in threadblock_shape]
    cluster = [str(dim) for dim in cluster_shape]
    aligns = [str(align) for align in alignments]

    substitutions = {
        "arch": str(arch),
        "eA": DataTypeNames[element_a],
        "eB": DataTypeNames[element_b],
        "eC": DataTypeNames[element_c],
        "lA": short_layouts[0],
        "lB": short_layouts[1],
        "lC": short_layouts[2],
        "opclass": OpcodeClassNames[opclass],
        "acc": DataTypeNames[element_accumulator],
        "cM": cluster[0],
        "cN": cluster[1],
        "cK": cluster[2],
        "tbM": tb[0],
        "tbN": tb[1],
        "tbK": tb[2],
        # 'auto' marks stage counts chosen automatically by the generator.
        "stages": "auto" if stages is None else str(stages),
        "aA": aligns[0],
        "aB": aligns[1],
        "aC": aligns[2],
        "k": KernelScheduleSuffixes[kernel_schedule] if kernel_schedule is not None else "",
        "e": EpilogueScheduleSuffixes[epilogue_schedule] if epilogue_schedule is not None else "",
        "suffix": suffix if suffix is not None else "",
    }
    return SubstituteTemplate(name_format, substitutions)
148
+
149
+
150
def add_test_gemm(
    cls=None,
    cc=None,
    element=None,
    layouts=None,
    alignments=None,
    element_output=None,
    element_accumulator=None,
    cluster_shape=None,
    threadblock_shape=None,
    warp_count=None,
    stages=None,
    opclass=None,
    swizzle=None,
    kernel_schedule=None,
    epilogue_schedule=None,
    compilation_modes=('nvcc', 'nvrtc'),
    element_A=None,
    element_B=None,
    element_C=None):
    """
    Create test-running functions with the given specification and set them as methods of ``cls``.

    :param cls: class to which the generated method will be added
    :type cls: type
    :param cc: compute capability to compile for
    :type cc: int
    :param element: data type of A and B operands
    :type element: cutlass_cppgen.DataType.f16
    :param layouts: layouts of A, B, and C operands
    :type layouts: list or tuple
    :param alignments: alignments of A, B, and C operands
    :type alignments: list or tuple
    :param element_output: data type of the output element
    :type element_output: cutlass_cppgen.DataType
    :param element_accumulator: data type used in accumulation
    :type element_accumulator: cutlass_cppgen.DataType
    :param cluster_shape: dimensions of clusters
    :type cluster_shape: list or tuple
    :param threadblock_shape: dimensions of threadblock tiles
    :type threadblock_shape: list or tuple
    :param warp_count: warps to be launched per threadblock dimension
    :type warp_count: list or tuple
    :param stages: number of pipeline stages to use in the kernel
    :type stages: int
    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
    :type opclass: cutlass_cppgen.OpcodeClass
    :param swizzle: threadblock swizzling functor
    :param kernel_schedule: kernel schedule to use
    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
    :param epilogue_schedule: epilogue schedule to use
    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
    :param compilation_modes: list of compilers to be used in testing the kernel (options: 'nvrtc', 'nvcc')
    :type compilation_modes: list or tuple
    :param element_A: data type of operand A. If set, overrides ``element``
    :type element_A: cutlass_cppgen.DataType
    :param element_B: data type of operand B. If set, overrides ``element``
    :type element_B: cutlass_cppgen.DataType
    :param element_C: data type of operand C. If set, overrides ``element``
    :type element_C: cutlass_cppgen.DataType
    """

    # ``element`` acts as a fallback for any operand type not set explicitly.
    if element_A is None:
        element_A = element
    if element_B is None:
        element_B = element
    if element_C is None:
        element_C = element
    if element_output is None:
        element_output = element
    if element_accumulator is None:
        element_accumulator = element

    def _make_run(compilation_mode):
        """
        Builds the test method for a single compilation mode.

        Using a factory (rather than defining ``run`` directly inside the loop
        below) binds ``compilation_mode`` at creation time. A closure over the
        loop variable would late-bind it, silently making every generated test
        compile with the *last* entry of ``compilation_modes``.
        """
        def run(self):
            """
            Dynamically-generated function that constructs a GEMM operation and verifies it against
            multiple test cases.
            """
            layout_A, layout_B, layout_C = layouts
            alignment_A, alignment_B, alignment_C = alignments

            # Build the GEMM plan for the requested types, layouts, and CC.
            plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B,
                                          element_C=element_C, element_D=element_output,
                                          layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
                                          element_accumulator=element_accumulator,
                                          kernel_cc=cc)

            plan.opclass = opclass
            if swizzle is not None:
                plan.swizzling_functor = swizzle

            # Start from the first valid tile description and override the
            # knobs this test pins down.
            td = plan.tile_descriptions()[0]
            if warp_count is not None:
                td.warp_count = warp_count
            td.threadblock_shape = threadblock_shape
            td.stages = stages
            td.cluster_shape = cluster_shape

            op = plan.construct(tile_description=td, alignment_A=alignment_A,
                                alignment_B=alignment_B, alignment_C=alignment_C)
            self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode))

        return run

    for compilation_mode in compilation_modes:
        element_epilogue = element_accumulator
        name = get_name(
            layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator,
            element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape,
            stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass,
            kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}')

        setattr(cls, name, _make_run(compilation_mode))
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Tests for a successful installation of the CUTLASS Python interface
35
+ """
36
+
37
+ import os
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ import cutlass_library
42
+
43
+
44
class InstallationTest(unittest.TestCase):
    def test_cutlass_source_paths(self):
        """
        Verifies that the CUTLASS C++ sources are bundled with both the
        cutlass_cppgen and cutlass_library packages.
        """
        src_file = 'include/cutlass/cutlass.h'
        # The header must be reachable under both installation roots.
        for root in (cutlass_library.source_path, cutlass_cppgen.CUTLASS_PATH):
            candidate = os.path.join(root, src_file)
            assert os.path.isfile(candidate), f"Unable to locate file {candidate}. Installation has not succeeded."


if __name__ == "__main__":
    unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Tests the high-level Conv2d interface
35
+ """
36
+
37
+ from math import ceil
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ import cutlass_cppgen.utils.datatypes as datatypes
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+ from utils import ExpectException
44
+ import os
45
+
46
+
47
class Conv2dEquivalence:
    """
    Helper class for testing the equivalence of different constructions of the Conv2d interface.

    A reference plan and operation are built in ``__init__`` from a fully
    explicit specification; each ``*_test`` method rebuilds the plan through an
    alternative construction path (generic element, numpy/torch tensor
    frontends, implicit accumulator) and asserts the emitted kernel matches.
    """
    def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator,
                 alignment_A, alignment_B, alignment_C):
        # Record the full specification so alternative constructions can reuse it.
        self.element_A = element_A
        self.element_B = element_B
        self.element_C = element_C
        self.element_D = element_D
        self.element_accumulator = element_accumulator
        self.alignment_A = alignment_A
        self.alignment_B = alignment_B
        self.alignment_C = alignment_C

        # One of "fprop", "wgrad", or "dgrad" (see registration loop below).
        self.conv_kind = conv_kind

        # Reference plan/operation against which all other constructions compare.
        self.plan = cutlass_cppgen.op.Conv2d(
            kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C,
            element_D=element_D, element_accumulator=element_accumulator)

        self.op = self.plan.construct(
            alignment_A=self.alignment_A, alignment_B=self.alignment_B,
            alignment_C=self.alignment_C)

    def _plans_equal(self, other_plan) -> bool:
        """
        Compares whether two plans are equal.

        :param other_plan: plan to compare against the default Conv2d
        :type other_plan: cutlass_cppgen.op.Conv2d

        :return: whether `other_plan` is equivalent to `self.plan`
        :rtype: bool
        """
        other_op = other_plan.construct(
            alignment_A=self.alignment_A, alignment_B=self.alignment_B,
            alignment_C=self.alignment_C)

        # Equivalence is judged on the emitted C++ kernel source.
        return self.op.rt_module.emit() == other_op.rt_module.emit()

    def generic_test(self):
        """
        Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types
        and layouts for constructing the Conv2d interface.
        """
        # NOTE(review): this guard mirrors numpy_test even though no numpy
        # frontend is exercised below — possibly copy-paste; confirm intent.
        if not datatypes.is_numpy_available():
            return

        # Test when specifying all parameters
        plan_other = cutlass_cppgen.op.Conv2d(
            kind=self.conv_kind,
            element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
            element_D=self.element_D, element_accumulator=self.element_accumulator)
        assert self._plans_equal(plan_other)

        # Test when specifying all parameters but A (A falls back to the generic ``element``)
        plan_other = cutlass_cppgen.op.Conv2d(
            kind=self.conv_kind,
            element_B=self.element_B, element_C=self.element_C,
            element_D=self.element_D, element_accumulator=self.element_accumulator,
            element=self.element_A)
        assert self._plans_equal(plan_other)

        # Test when specifying all parameters but A and B, using generic element and output
        plan_other = cutlass_cppgen.op.Conv2d(
            kind=self.conv_kind,
            element_C=self.element_C,
            element_D=self.element_D, element_accumulator=self.element_accumulator,
            element=self.element_A)
        assert self._plans_equal(plan_other)

        # Test without explicit accumulator. Only run if the type of C and the accumulator are equal
        if self.element_C == self.element_accumulator:
            plan_other = cutlass_cppgen.op.Conv2d(
                kind=self.conv_kind,
                element_C=self.element_C,
                element_D=self.element_D,
                element=self.element_A)
            assert self._plans_equal(plan_other)

        # Test with only the generic types. Only run if the types of A, B, C, and D are the same
        if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
            and self.element_A == self.element_accumulator):
            plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A)
            assert self._plans_equal(plan_other)

    def numpy_test(self):
        """
        Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend.
        Silently skipped when numpy is not installed.
        """
        if not datatypes.is_numpy_available():
            return

        import numpy as np
        type_A = datatypes.numpy_type(self.element_A)
        type_B = datatypes.numpy_type(self.element_B)
        type_C = datatypes.numpy_type(self.element_C)
        type_D = datatypes.numpy_type(self.element_D)
        type_accum = datatypes.numpy_type(self.element_accumulator)

        # Tiny placeholder tensors: only their dtypes matter for plan construction.
        size = (2, 2)
        A = np.zeros(size, dtype=type_A)
        B = np.zeros(size, dtype=type_B)
        C = np.zeros(size, dtype=type_C)
        D = np.zeros(size, dtype=type_D)

        return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)

    def torch_test(self):
        """
        Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend.
        Silently skipped when torch is not installed.
        """
        if not datatypes.is_torch_available():
            return

        import torch
        type_A = datatypes.torch_type(self.element_A)
        type_B = datatypes.torch_type(self.element_B)
        type_C = datatypes.torch_type(self.element_C)
        type_D = datatypes.torch_type(self.element_D)
        type_accum = datatypes.torch_type(self.element_accumulator)

        # Tiny placeholder tensors: only their dtypes matter for plan construction.
        size = (2, 2)

        A = torch.empty(size, dtype=type_A)
        B = torch.empty(size, dtype=type_B)
        C = torch.empty(size, dtype=type_C)
        D = torch.empty(size, dtype=type_D)

        return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)

    def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D):
        """
        Shared body of numpy_test/torch_test: checks equivalence when element
        types are inferred from framework tensors instead of given explicitly.
        """
        # Test when specifying all parameters via tensors
        plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum)
        assert self._plans_equal(plan_np)

        # Test when specifying all parameters but A as tensors
        plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A)
        assert self._plans_equal(plan_np)

        # Test when specifying all parameters but A and B as tensors and using generic element and output
        if type_A == type_B:
            plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A)
            assert self._plans_equal(plan_np)

        # Test without explicit accumulator. Only run if the type of C and the accumulator match.
        if type_C == type_accum:
            plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D)
            assert self._plans_equal(plan_np)

        # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
        if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum):
            plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A)
            assert self._plans_equal(plan_np)

    def test_all(self):
        """
        Runs all equivalence tests on the Conv2d interface.
        """
        self.generic_test()
        self.numpy_test()
        self.torch_test()
211
+
212
+
213
@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.')
class ConvEquivalenceTest(unittest.TestCase):
    """
    Empty container class; per-type-combination equivalence test methods are
    attached dynamically via add_test in this module.
    """
    pass
219
+
220
# Maximal vectorized-access alignment (in elements) per data type.
type2alignment = {
    cutlass_cppgen.DataType.f16: 8,
    cutlass_cppgen.DataType.f32: 4
}

def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator):
    # Builds one test method exercising Conv2dEquivalence for a single type
    # combination and attaches it to ConvEquivalenceTest under a descriptive name.
    test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}"

    def run(self):
        checker = Conv2dEquivalence(
            conv_kind=conv_kind,
            element_A=element_A, element_B=element_B,
            element_C=element_C, element_D=element_D,
            element_accumulator=element_accumulator,
            alignment_A=type2alignment[element_A],
            alignment_B=type2alignment[element_B],
            alignment_C=type2alignment[element_C]
        )
        checker.test_all()

    setattr(ConvEquivalenceTest, test_name, run)

# Register every conv kind against a representative set of
# (A, B, C, D, accumulator) type combinations.
for conv_kind in ["fprop", "wgrad", "dgrad"]:
    for types in [
        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16],
        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32],
        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16],
        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32],
        [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32]
    ]:
        add_test(conv_kind, *types)
251
+
252
+
253
@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.')
class Conv2dErrorTests(unittest.TestCase):
    """
    Exercises error scenarios that arise with the high-level Conv2d interface
    """

    def test_alignment(self):
        """
        Constructing with an alignment that F16 does not support must fail.
        """
        plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16)

        with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'):
            plan.construct(alignment_A=3, alignment_B=3, alignment_C=3)

    def test_invalid_tile_description(self):
        """
        Compiling with a malformed threadblock shape must fail for this CC.
        """
        plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16)

        td = plan.tile_descriptions()[0]
        td.threadblock_shape = [17, 32, 5]
        plan.tile_description = td

        with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'):
            plan.compile()

        # Clean up the error log emitted by the failed device compilation.
        os.remove("./cutlass_python_compilation_device_error.txt")


if __name__ == '__main__':
    unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Test the EVT interface
35
+ """
36
+
37
+ import numpy as np
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ from cutlass_cppgen import LayoutType, Tensor
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+ from cutlass_cppgen.epilogue import reshape, permute
44
+
45
+ from utils import ExpectException
46
+
47
+
48
+ @unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only")
49
+ class EVTErrorTests(unittest.TestCase):
50
+ """
51
+ Tests various error scenarios that arise with the EVT interface
52
+ """
53
+ @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'")
54
+ def test_root_not_d(self):
55
+ """
56
+ Test when "D" does not exist in Sm90 EVT
57
+ """
58
+ def evt_root_not_d(accum, alpha):
59
+ F = accum * alpha
60
+ return F
61
+
62
+ example_tensors = {
63
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
64
+ "alpha": 1.2,
65
+ "F": self.fake_tensor(np.float16, (6, 512, 512))
66
+ }
67
+
68
+ with ExpectException(device_cc() == 90,
69
+ "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, "
70
+ "but the variable 'D' is not found in the return values.", True):
71
+
72
+ cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors)
73
+
74
+ def test_no_accum(self):
75
+ """
76
+ Test when "accum" is not in input arguments
77
+ """
78
+ def evt_no_accum(alpha, C):
79
+ D = alpha * C
80
+ return D
81
+
82
+ example_tensors = {
83
+ "C": self.fake_tensor(np.float16, (6, 512, 512)),
84
+ "alpha": 1.2,
85
+ "D": self.fake_tensor(np.float16, (6, 512, 512))
86
+ }
87
+
88
+ with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True):
89
+ cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors)
90
+
91
+ @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size")
92
+ def test_too_much_shared_memory(self):
93
+ """
94
+ Test when the epilogue consumes too much shared memory
95
+ """
96
+ def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8):
97
+ D1 = accum + C1
98
+ D2 = D1 + C2
99
+ D3 = D2 + C3
100
+ D4 = D3 + C4
101
+ D5 = D4 + C5
102
+ D6 = D5 + C6
103
+ D7 = D6 + C7
104
+ D = D7 + C8
105
+ return D, D1, D2, D3, D4, D5, D6, D7
106
+
107
+ example_tensors = {
108
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
109
+ "C1": self.fake_tensor(np.float16, (6, 512, 512)),
110
+ "C2": self.fake_tensor(np.float16, (6, 512, 512)),
111
+ "C3": self.fake_tensor(np.float16, (6, 512, 512)),
112
+ "C4": self.fake_tensor(np.float16, (6, 512, 512)),
113
+ "C5": self.fake_tensor(np.float16, (6, 512, 512)),
114
+ "C6": self.fake_tensor(np.float16, (6, 512, 512)),
115
+ "C7": self.fake_tensor(np.float16, (6, 512, 512)),
116
+ "C8": self.fake_tensor(np.float16, (6, 512, 512)),
117
+ "D1": self.fake_tensor(np.float16, (6, 512, 512)),
118
+ "D2": self.fake_tensor(np.float16, (6, 512, 512)),
119
+ "D3": self.fake_tensor(np.float16, (6, 512, 512)),
120
+ "D4": self.fake_tensor(np.float16, (6, 512, 512)),
121
+ "D5": self.fake_tensor(np.float16, (6, 512, 512)),
122
+ "D6": self.fake_tensor(np.float16, (6, 512, 512)),
123
+ "D7": self.fake_tensor(np.float16, (6, 512, 512)),
124
+ "D": self.fake_tensor(np.float16, (6, 512, 512))
125
+ }
126
+
127
+ epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors)
128
+
129
+ plan = cutlass_cppgen.op.Gemm(
130
+ element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor,
131
+ element_accumulator=np.float32
132
+ )
133
+
134
+ with ExpectException(True,
135
+ "RuntimeError: The epilogue consumes too much shared memory. "
136
+ "No valid tile description is found in the generator.", True):
137
+ plan.epilogue_visitor = epilogue_visitor
138
+
139
+ def test_not_ssa(self):
140
+ """
141
+ Test when the epilogue is not in SSA
142
+ """
143
+ def evt_redefine(accum, C, alpha):
144
+ F = accum + C
145
+ F = F * alpha
146
+ D = F
147
+ return D, F
148
+
149
+ example_tensors = {
150
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
151
+ "C": self.fake_tensor(np.float16, (6, 512, 512)),
152
+ "alpha": 1.5,
153
+ "D": self.fake_tensor(np.float16, (6, 512, 512)),
154
+ "F": self.fake_tensor(np.float16, (6, 512, 512))
155
+ }
156
+
157
+ with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True):
158
+ cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors)
159
+
160
+ def evt_undefine(accum, alpha):
161
+ F = accum + C
162
+ D = F * alpha
163
+ return D, F
164
+
165
+ example_tensors = {
166
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
167
+ "alpha": 1.5,
168
+ "D": self.fake_tensor(np.float16, (6, 512, 512)),
169
+ "F": self.fake_tensor(np.float16, (6, 512, 512))
170
+ }
171
+
172
+ with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True):
173
+ cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors)
174
+
175
+ def test_missing_example_tensor(self):
176
+ """
177
+ Test when the example tensor of an input/output variable is not provided
178
+ """
179
+ def evt_missing_example_tensor(accum, C):
180
+ D = accum + C
181
+ return D
182
+
183
+ example_tensors = {
184
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
185
+ "C": self.fake_tensor(np.float16, (6, 512, 512)),
186
+ }
187
+
188
+ with ExpectException(True, "RuntimeError: Example input for D is not provided.", True):
189
+ cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors)
190
+
191
+ example_tensors = {
192
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
193
+ "D": self.fake_tensor(np.float16, (6, 512, 512)),
194
+ }
195
+
196
+ with ExpectException(True, "RuntimeError: Example input for C is not provided.", True):
197
+ cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors)
198
+
199
+ def test_return_expression(self):
200
+ """
201
+ Test when the return value is an expression
202
+ """
203
+ def evt_return_expr(accum, C):
204
+ return accum + C
205
+
206
+ example_tensors = {
207
+ "accum": self.fake_tensor(np.float16, (6, 512, 512)),
208
+ "C": self.fake_tensor(np.float16, (6, 512, 512)),
209
+ }
210
+
211
+ with ExpectException(True, "SyntaxError: Return value cannot be an expression", True):
212
+ cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors)
213
+
214
+ def test_incompatible_shape(self):
215
+ """
216
+ Test when the shape of example tensors are incompatible
217
+ """
218
+ def evt_incompatible_shape(accum, C):
219
+ D = accum + C
220
+ return D
221
+
222
+ example_tensors = {
223
+ "accum": self.fake_tensor(np.float16, (6, 256, 512)),
224
+ "C": self.fake_tensor(np.float16, (6, 512, 512)),
225
+ "D": self.fake_tensor(np.float16, (6, 512, 512))
226
+ }
227
+
228
+ with ExpectException(True,
229
+ "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True):
230
+ cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors)
231
+
232
+ def test_no_matching_impl(self):
233
+ def evt_no_matching_impl(accum, bias):
234
+ D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1))
235
+ return D
236
+
237
+ example_tensors = {
238
+ "accum": self.fake_tensor(np.float16, (6, 512, 256)),
239
+ "bias": self.fake_tensor(np.float16, (16, 32)),
240
+ "D": self.fake_tensor(np.float16, (6, 512, 256))
241
+ }
242
+
243
+ with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True):
244
+ cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors)
245
+ #
246
+ # Helper functions
247
+ #
248
+
249
+ def fake_tensor(self, element, shape):
250
+ return Tensor(element=element, shape=shape, layout_tag=LayoutType.RowMajor)
251
+
252
+
253
+ if __name__ == '__main__':
254
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Tests the high-level GEMM interface
35
+ """
36
+
37
+ from math import ceil
38
+ import unittest
39
+
40
+ import cutlass_cppgen
41
+ import cutlass_cppgen.utils.datatypes as datatypes
42
+ from cutlass_cppgen.backend.utils.device import device_cc
43
+ from utils import ExpectException
44
+
45
+
46
+ class GemmEquivalence:
47
+ """
48
+ Helper class for testing the equivalence of different constructions of the Gemm interface
49
+ """
50
+ def __init__(self, element_A, element_B, element_C, element_D, element_accumulator,
51
+ layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C):
52
+ self.element_A = element_A
53
+ self.element_B = element_B
54
+ self.element_C = element_C
55
+ self.element_D = element_D
56
+ self.element_accumulator = element_accumulator
57
+ self.layout_A = layout_A
58
+ self.layout_B = layout_B
59
+ self.layout_C = layout_C
60
+ self.alignment_A = alignment_A
61
+ self.alignment_B = alignment_B
62
+ self.alignment_C = alignment_C
63
+ self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C,
64
+ element_D=element_D, element_accumulator=element_accumulator,
65
+ layout_A=layout_A, layout_B=layout_B, layout_C=layout_C)
66
+ self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C)
67
+
68
+ def _plans_equal(self, other_plan) -> bool:
69
+ """
70
+ Compares whether two plans are equal
71
+
72
+ :param other_plan: plan to compare against the default GEMM
73
+ :type other_plan: cutlass_cppgen.op.Gemm
74
+
75
+ :return: whether `other_plan` is equivalent to `self.plan`
76
+ :rtype: bool
77
+ """
78
+ other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C)
79
+
80
+ # Compare whether the operations are equal by comparing the C++ code that would be emitted for them
81
+ return self.op.rt_module.emit() == other_op.rt_module.emit()
82
+
83
+ def generic_test(self):
84
+ """
85
+ Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types
86
+ and layouts for constructing the Gemm interface
87
+ """
88
+ if not datatypes.is_numpy_available():
89
+ return
90
+
91
+ # Test when specifying all parameters
92
+ plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
93
+ element_D=self.element_D, element_accumulator=self.element_accumulator,
94
+ layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C)
95
+ assert self._plans_equal(plan_other)
96
+
97
+ # Test when specifying all parameters but A
98
+ plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C,
99
+ element_D=self.element_D, element_accumulator=self.element_accumulator,
100
+ layout_B=self.layout_B, layout_C=self.layout_C,
101
+ element=self.element_A, layout=self.layout_A)
102
+ assert self._plans_equal(plan_other)
103
+
104
+ # Test when specifying all parameters but A and B as tensors and using generic element and output
105
+ # Only run this test if the layouts and types for A and B are equal.
106
+ if self.element_A == self.element_B and self.layout_A == self.layout_B:
107
+ plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator,
108
+ layout_C=self.layout_C, element=self.element_A, layout=self.layout_A)
109
+ assert self._plans_equal(plan_other)
110
+
111
+ # Test without explicit accumulator. Only run if the type of C and the accumulator.
112
+ if self.element_C == self.element_accumulator:
113
+ plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C,
114
+ element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B,
115
+ layout_C=self.layout_C)
116
+ assert self._plans_equal(plan_other)
117
+
118
+ # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
119
+ if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
120
+ and self.element_A == self.element_accumulator and
121
+ self.layout_A == self.layout_B and self.layout_A == self.layout_C):
122
+ plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A)
123
+ assert self._plans_equal(plan_other)
124
+
125
+ def numpy_test(self):
126
+ """
127
+ Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend
128
+ """
129
+ if not datatypes.is_numpy_available():
130
+ return
131
+
132
+ import numpy as np
133
+ type_A = datatypes.numpy_type(self.element_A)
134
+ type_B = datatypes.numpy_type(self.element_B)
135
+ type_C = datatypes.numpy_type(self.element_C)
136
+ type_D = datatypes.numpy_type(self.element_D)
137
+ type_accum = datatypes.numpy_type(self.element_accumulator)
138
+
139
+ layout_to_order = {
140
+ cutlass_cppgen.LayoutType.RowMajor: 'C',
141
+ cutlass_cppgen.LayoutType.ColumnMajor: 'F'
142
+ }
143
+ size = (2, 2)
144
+ A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A)
145
+ B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B)
146
+ C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C)
147
+ D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D)
148
+
149
+ # Test when specifying all parameters via tensors
150
+ plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum)
151
+ assert self._plans_equal(plan_np)
152
+
153
+ # Test when specifying all parameters but A as tensors
154
+ plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A)
155
+ assert self._plans_equal(plan_np)
156
+
157
+ # Test when specifying all parameters but A and B as tensors and using generic element and output
158
+ # Only run this test if the layouts and types for A and B are equal.
159
+ if type_A == type_B and self.layout_A == self.layout_B:
160
+ plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A)
161
+ assert self._plans_equal(plan_np)
162
+
163
+ # Test without explicit accumulator. Only run if the type of C and the accumulator.
164
+ if type_C == type_accum:
165
+ plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D)
166
+ assert self._plans_equal(plan_np)
167
+
168
+ # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same.
169
+ if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and
170
+ self.layout_A == self.layout_B and self.layout_A == self.layout_C):
171
+ plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A)
172
+ assert self._plans_equal(plan_np)
173
+
174
+ def test_all(self):
175
+ """
176
+ Runs all tests on the Gemm interface
177
+ """
178
+ self.generic_test()
179
+ self.numpy_test()
180
+
181
+
182
+ class GemmEquivalenceTest(unittest.TestCase):
183
+ """
184
+ Tests the equivalence of different constructions of the Gemm interface
185
+ """
186
+ @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
187
+ def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self):
188
+ gemm_eq = GemmEquivalence(
189
+ element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
190
+ element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16,
191
+ layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor,
192
+ alignment_A=8, alignment_B=8, alignment_C=8)
193
+ gemm_eq.test_all()
194
+
195
+ @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
196
+ def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self):
197
+ gemm_eq = GemmEquivalence(
198
+ element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
199
+ element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32,
200
+ layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor,
201
+ alignment_A=8, alignment_B=8, alignment_C=8)
202
+ gemm_eq.test_all()
203
+
204
+ @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.")
205
+ def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self):
206
+ gemm_eq = GemmEquivalence(
207
+ element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16,
208
+ element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16,
209
+ layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor,
210
+ alignment_A=8, alignment_B=8, alignment_C=8)
211
+ gemm_eq.test_all()
212
+
213
+ @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.")
214
+ def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self):
215
+ gemm_eq = GemmEquivalence(
216
+ element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64,
217
+ element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64,
218
+ layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor,
219
+ alignment_A=1, alignment_B=1, alignment_C=1)
220
+ gemm_eq.test_all()
221
+
222
+
223
+ class GemmErrorTests(unittest.TestCase):
224
+ """
225
+ Tests various error scenarios that arise with the high-level Gemm interface
226
+ """
227
+
228
+ def test_alignment(self):
229
+ """
230
+ Tests case in which the alignment specified is unsupported
231
+ """
232
+ plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
233
+
234
+ with ExpectException(True, 'Alignment 16 is not supported for F16. The construction should fail.'):
235
+ op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16)
236
+
237
+ def test_tensorop_availability(self):
238
+ """
239
+ Tests case in which only SIMT operations are available but TensorOp is requested
240
+ """
241
+ cc = device_cc()
242
+
243
+ # F64 Tensor Core operations are only avaiable on certain devices
244
+ supports_tensorop_f64 = cc in [80, 89, 90]
245
+ plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor)
246
+
247
+ error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}'
248
+ with ExpectException(not supports_tensorop_f64, error_msg):
249
+ plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp
250
+
251
+ expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt
252
+ assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}'
253
+
254
+ @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.")
255
+ def test_opclass_switch(self):
256
+ """
257
+ Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT)
258
+ """
259
+ plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
260
+ assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp
261
+
262
+ # Ensure that all tile descriptions have opclass of TensorOp
263
+ for td in plan.tile_descriptions():
264
+ assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp
265
+
266
+ plan.opclass = cutlass_cppgen.OpcodeClass.Simt
267
+
268
+ # Ensure that all tile descriptions have opclass of Simt
269
+ for td in plan.tile_descriptions():
270
+ assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt
271
+
272
+ def test_invalid_tile_description(self):
273
+ """
274
+ Tests scenarios in which an invalid tile description is provided for a given CC
275
+ """
276
+ cc = device_cc()
277
+ plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor)
278
+ td = plan.tile_descriptions()[0]
279
+ stages = td.stages
280
+
281
+ # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage
282
+ # count should be used
283
+ with ExpectException(cc < 90, f'Requested zero stages'):
284
+ td.stages = 0
285
+ plan.construct(td)
286
+
287
+ if cc < 90:
288
+ with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'):
289
+ td.stages = 3
290
+ plan.construct(td)
291
+ elif cc == 90:
292
+ original_kschedule = td.kernel_schedule
293
+ original_eschedule = td.epilogue_schedule
294
+ with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'):
295
+ td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong
296
+ td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized
297
+ td.stages = 3
298
+ plan.construct(td)
299
+ # Reset schedules
300
+ td.kernel_schedule = original_kschedule
301
+ td.epilogue_schedule = original_eschedule
302
+ elif cc in [100, 101, 103]:
303
+ with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'):
304
+ td.stages = 3
305
+ plan.construct(td)
306
+
307
+ with ExpectException(True, f'Requested too many stages'):
308
+ td.stages = 100
309
+ plan.construct(td)
310
+
311
+ # Reset stage count
312
+ td.stages = stages
313
+
314
+ cluster_shape = td.cluster_shape
315
+ with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'):
316
+ td.cluster_shape = [2, 1, 1]
317
+ plan.construct(td)
318
+
319
+ # Reset cluster shape
320
+ td.cluster_shape = cluster_shape
321
+
322
+ with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'):
323
+ td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong
324
+ td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
325
+ plan.construct(td)
326
+
327
+ with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'):
328
+ td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong
329
+ td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto
330
+ plan.construct(td)
331
+
332
+ with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'):
333
+ td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto
334
+ td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized
335
+ plan.construct(td)
336
+
337
+ with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'):
338
+ td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative
339
+ td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative
340
+ td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK
341
+ plan.construct(td)
342
+
343
+ # Ensure that all returned tile descriptions are unique
344
+ ops = {}
345
+ for i, td in enumerate(plan.tile_descriptions()):
346
+ op = plan.construct(td)
347
+ code_str = op.rt_module.emit()
348
+ if code_str in ops:
349
+ conflicting_td = ops[code_str]
350
+ assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}'
351
+
352
+
353
+ if __name__ == '__main__':
354
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Helper functions & classes for interface test
35
+ """
36
+ class ExpectException:
37
+ """
38
+ Utility class to assert that an exception was raised when expected
39
+
40
+ Example:
41
+
42
+ .. highlight:: python
43
+ .. code-block:: python
44
+
45
+ with ExceptionExpected(True, 'Division by zero'):
46
+ x = 1.0 / 0.0
47
+
48
+ :param exception_expected: whether an exception is expected to be raised
49
+ :type exception_expected: bool
50
+ :param message: message to print if an exception is raised when not expected or vice versa
51
+ :type message: str
52
+ """
53
+ def __init__(self, exception_expected: bool, message: str = '', verify_msg=False):
54
+ self.exception_expected = exception_expected
55
+ self.message = message
56
+ self.verify_msg = verify_msg
57
+
58
+ def __enter__(self):
59
+ return self
60
+
61
+ def __exit__(self, exc_type, exc_val, traceback):
62
+ exception_raised = exc_type is not None
63
+ assert self.exception_expected == exception_raised, self.message
64
+ if self.verify_msg:
65
+ exc_message = f"{exc_type.__name__}: {exc_val}"
66
+ assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}"
67
+
68
+ # Suppress the exception
69
+ return True
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Utility script for discovering and running all PyCuTe tests
35
+ """
36
+
37
+ import argparse
38
+ import logging
39
+ import pathlib
40
+ import unittest
41
+
42
+
43
+ def numeric_log_level(log_level: str) -> int:
44
+ """
45
+ Converts the string identifier of the log level into the numeric identifier used
46
+ in setting the log level
47
+
48
+ :param x: string representation of log level (e.g., 'INFO', 'DEBUG')
49
+ :type x: str
50
+
51
+ :return: numeric representation of log level
52
+ :rtype: int
53
+ """
54
+ numeric_level = getattr(logging, log_level.upper(), None)
55
+ if not isinstance(numeric_level, int):
56
+ raise ValueError(f"Invalid log level: {log_level}")
57
+ return numeric_level
58
+
59
+
60
+ if __name__ == "__main__":
61
+ parser = argparse.ArgumentParser()
62
+ parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False,
63
+ help='Logging level to be used by the generator script')
64
+ args = parser.parse_args()
65
+
66
+ # Set the logging level based on the user-provided `--log-level` command-line option
67
+ logging.basicConfig(level=args.log_level)
68
+
69
+ loader = unittest.TestLoader()
70
+ script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/'
71
+ tests = loader.discover(script_dir, "test_*.py")
72
+ test_runner = unittest.runner.TextTestRunner()
73
+ results = test_runner.run(tests)
74
+ if not results.wasSuccessful():
75
+ raise Exception("Test cases failed")
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.coalesce
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ from pycute import *
41
+
42
+ _LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ class TestCoalesce(unittest.TestCase):
46
+ def helper_test_coalesce(self, layout):
47
+ layoutR = coalesce(layout)
48
+
49
+ _LOGGER.debug(f"{layout} => {layoutR}")
50
+
51
+ self.assertEqual(size(layoutR), size(layout))
52
+
53
+ for i in range(size(layout)):
54
+ self.assertEqual(layoutR(i), layout(i))
55
+
56
+ def test_coalesce(self):
57
+ layout = Layout(1,0)
58
+ self.helper_test_coalesce(layout)
59
+
60
+ layout = Layout(1,1)
61
+ self.helper_test_coalesce(layout)
62
+
63
+ layout = Layout((2,4))
64
+ self.helper_test_coalesce(layout)
65
+
66
+ layout = Layout((2,4,6))
67
+ self.helper_test_coalesce(layout)
68
+
69
+ layout = Layout((2,4,6), (1,6,2))
70
+ self.helper_test_coalesce(layout)
71
+
72
+ layout = Layout((2,1,6), (1,7,2))
73
+ self.helper_test_coalesce(layout)
74
+
75
+ layout = Layout((2,1,6), (4,7,8))
76
+ self.helper_test_coalesce(layout)
77
+
78
+ layout = Layout((2,(4,6)))
79
+ self.helper_test_coalesce(layout)
80
+
81
+ layout = Layout((2,4), (4,1))
82
+ self.helper_test_coalesce(layout)
83
+
84
+ layout = Layout((2,4,6), (24,6,1))
85
+ self.helper_test_coalesce(layout)
86
+
87
+ layout = Layout((2,1,3), (2,4,4))
88
+ self.helper_test_coalesce(layout)
89
+
90
+ layout = Layout(((2,2),(2,2)), ((1,4),(8,32)))
91
+ self.helper_test_coalesce(layout)
92
+
93
+
94
+ if __name__ == "__main__":
95
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.complement
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ from pycute import *
41
+
42
+ _LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ class TestComplement(unittest.TestCase):
46
+ def helper_test_complement(self, layout):
47
+ layoutR = complement(layout)
48
+
49
+ _LOGGER.debug(f"{layout} => {layoutR}")
50
+
51
+ # Post-condition: test disjointness of the codomains
52
+ for a in range(size(layout)):
53
+ for b in range(size(layoutR)):
54
+ assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0)
55
+
56
+ def test_complement(self):
57
+ test = Layout(1,0)
58
+ self.helper_test_complement(test)
59
+
60
+ test = Layout(1,1)
61
+ self.helper_test_complement(test)
62
+
63
+ test = Layout(4,0)
64
+ self.helper_test_complement(test)
65
+
66
+ test = Layout((2,4),(1,2))
67
+ self.helper_test_complement(test)
68
+
69
+ test = Layout((2,3),(1,2))
70
+ self.helper_test_complement(test)
71
+
72
+ test = Layout((2,4),(1,4))
73
+ self.helper_test_complement(test)
74
+
75
+ test = Layout((2,4,8),(8,1,64))
76
+ self.helper_test_complement(test)
77
+
78
+ test = Layout(((2,2),(2,2)),((1,4),(8,32)))
79
+ self.helper_test_complement(test)
80
+
81
+ test = Layout((2,(3,4)),(3,(1,6)))
82
+ self.helper_test_complement(test)
83
+
84
+ test = Layout((4,6),(1,6))
85
+ self.helper_test_complement(test)
86
+
87
+ test = Layout((4,10),(1,10))
88
+ self.helper_test_complement(test)
89
+
90
+
91
+ if __name__ == "__main__":
92
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.composition
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ from pycute import *
41
+
42
+ _LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ class TestComposition(unittest.TestCase):
46
+ def helper_test_composition(self, layoutA, layoutB):
47
+ layoutR = composition(layoutA, layoutB)
48
+
49
+ _LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}")
50
+
51
+ # True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR.
52
+
53
+ # Test that R(c) = A(B(c)) for all coordinates c in layoutR
54
+ for i in range(size(layoutR)):
55
+ self.assertEqual(layoutR(i), layoutA(layoutB(i)))
56
+
57
+ def test_composition(self):
58
+ layoutA = Layout(1,0)
59
+ layoutB = Layout(1,0)
60
+ self.helper_test_composition(layoutA, layoutB)
61
+
62
+ layoutA = Layout(1,0)
63
+ layoutB = Layout(1,1)
64
+ self.helper_test_composition(layoutA, layoutB)
65
+
66
+ layoutA = Layout(1,1)
67
+ layoutB = Layout(1,0)
68
+ self.helper_test_composition(layoutA, layoutB)
69
+
70
+ layoutA = Layout(1,1)
71
+ layoutB = Layout(1,1)
72
+ self.helper_test_composition(layoutA, layoutB)
73
+
74
+ layoutA = Layout((4))
75
+ layoutB = Layout((4))
76
+ self.helper_test_composition(layoutA, layoutB)
77
+
78
+ layoutA = Layout((4), (2))
79
+ layoutB = Layout((4))
80
+ self.helper_test_composition(layoutA, layoutB)
81
+
82
+ layoutA = Layout((4))
83
+ layoutB = Layout((4), (2))
84
+ self.helper_test_composition(layoutA, layoutB)
85
+
86
+ layoutA = Layout((4), (0))
87
+ layoutB = Layout((4))
88
+ self.helper_test_composition(layoutA, layoutB)
89
+
90
+ layoutA = Layout((4))
91
+ layoutB = Layout((4), (0))
92
+ self.helper_test_composition(layoutA, layoutB)
93
+
94
+ layoutA = Layout((1), (0))
95
+ layoutB = Layout((4))
96
+ self.helper_test_composition(layoutA, layoutB)
97
+
98
+ layoutA = Layout((4))
99
+ layoutB = Layout((1), (0))
100
+ self.helper_test_composition(layoutA, layoutB)
101
+
102
+ layoutA = Layout((4))
103
+ layoutB = Layout((2))
104
+ self.helper_test_composition(layoutA, layoutB)
105
+
106
+ layoutA = Layout((4), (2))
107
+ layoutB = Layout((2))
108
+ self.helper_test_composition(layoutA, layoutB)
109
+
110
+ layoutA = Layout((4))
111
+ layoutB = Layout((2), (2))
112
+ self.helper_test_composition(layoutA, layoutB)
113
+
114
+ layoutA = Layout((4), (2))
115
+ layoutB = Layout((2), (2))
116
+ self.helper_test_composition(layoutA, layoutB)
117
+
118
+ layoutA = Layout((12))
119
+ layoutB = Layout((4,3))
120
+ self.helper_test_composition(layoutA, layoutB)
121
+
122
+ layoutA = Layout((12), (2))
123
+ layoutB = Layout((4,3))
124
+ self.helper_test_composition(layoutA, layoutB)
125
+
126
+ layoutA = Layout((12))
127
+ layoutB = Layout((4,3), (3,1))
128
+ self.helper_test_composition(layoutA, layoutB)
129
+
130
+ layoutA = Layout((12), (2))
131
+ layoutB = Layout((4,3), (3,1))
132
+ self.helper_test_composition(layoutA, layoutB)
133
+
134
+ layoutA = Layout((12))
135
+ layoutB = Layout((2,3), (2,4))
136
+ self.helper_test_composition(layoutA, layoutB)
137
+
138
+ layoutA = Layout((4,3))
139
+ layoutB = Layout((4,3))
140
+ self.helper_test_composition(layoutA, layoutB)
141
+
142
+ layoutA = Layout((4,3))
143
+ layoutB = Layout((12))
144
+ self.helper_test_composition(layoutA, layoutB)
145
+
146
+ layoutA = Layout((4,3))
147
+ layoutB = Layout((6), (2))
148
+ self.helper_test_composition(layoutA, layoutB)
149
+
150
+ layoutA = Layout((4,3))
151
+ layoutB = Layout((6,2), (2,1))
152
+ self.helper_test_composition(layoutA, layoutB)
153
+
154
+ layoutA = Layout((4,3), (3,1))
155
+ layoutB = Layout((4,3))
156
+ self.helper_test_composition(layoutA, layoutB)
157
+
158
+ layoutA = Layout((4,3), (3,1))
159
+ layoutB = Layout((12))
160
+ self.helper_test_composition(layoutA, layoutB)
161
+
162
+ layoutA = Layout((4,3), (3,1))
163
+ layoutB = Layout((6), (2))
164
+ self.helper_test_composition(layoutA, layoutB)
165
+
166
+ layoutA = Layout((4,3), (3,1))
167
+ layoutB = Layout((6,2), (2,1))
168
+ self.helper_test_composition(layoutA, layoutB)
169
+
170
+ layoutA = Layout((8,8))
171
+ layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32)))
172
+ self.helper_test_composition(layoutA, layoutB)
173
+
174
+ layoutA = Layout((8,8), (8,1))
175
+ layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32)))
176
+ self.helper_test_composition(layoutA, layoutB)
177
+
178
+ layoutA = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32)))
179
+ layoutB = Layout(8, 4)
180
+ self.helper_test_composition(layoutA, layoutB)
181
+
182
+ layoutA = Layout(((4,2)), ((1,16)))
183
+ layoutB = Layout((4,2), (2,1))
184
+ self.helper_test_composition(layoutA, layoutB)
185
+
186
+ layoutA = Layout((2,2), (2,1))
187
+ layoutB = Layout((2,2), (2,1))
188
+ self.helper_test_composition(layoutA, layoutB)
189
+
190
+ layoutA = Layout((4,8,2))
191
+ layoutB = Layout((2,2,2), (2,8,1))
192
+ self.helper_test_composition(layoutA, layoutB)
193
+
194
+ layoutA = Layout((4,8,2), (2,8,1))
195
+ layoutB = Layout((2,2,2), (1,8,2))
196
+ self.helper_test_composition(layoutA, layoutB)
197
+
198
+ layoutA = Layout((4,8,2), (2,8,1))
199
+ layoutB = Layout((4,2,2), (2,8,1))
200
+ self.helper_test_composition(layoutA, layoutB)
201
+
202
+ # Pre-coalesced LHS
203
+ layoutA = Layout((4,6,8),(1,4,7))
204
+ layoutB = Layout((6),(1))
205
+ self.helper_test_composition(layoutA, layoutB)
206
+
207
+ # Mid-layout truncation
208
+ layoutA = Layout((4,6,8,10),(2,3,5,7))
209
+ layoutB = Layout(6,12)
210
+ self.helper_test_composition(layoutA, layoutB)
211
+
212
+ if __name__ == "__main__":
213
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.int_tuple
35
+ """
36
+
37
+ import unittest
38
+
39
+ from pycute import *
40
+
41
+
42
+ class TestIntTuple(unittest.TestCase):
43
+ def test_product(self):
44
+ self.assertEqual(product(2), 2)
45
+
46
+ self.assertEqual(product((3,2)), 6)
47
+
48
+ self.assertEqual(product(product(((2,3),4))), 24)
49
+
50
+ def test_inner_product(self):
51
+ self.assertEqual(inner_product(2, 3), 6)
52
+
53
+ self.assertEqual(inner_product((1,2), (3,2)), 7)
54
+
55
+ self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15)
56
+
57
+ def test_shape_div(self):
58
+ self.assertEqual(shape_div((3,4), 6), (1,2))
59
+
60
+ self.assertEqual(shape_div((3,4), 12), (1,1))
61
+
62
+ self.assertEqual(shape_div((3,4), 36), (1,1))
63
+
64
+ self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2))
65
+
66
+ self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2)))
67
+
68
+ def test_prefix_product(self):
69
+ self.assertEqual(prefix_product(2), 1)
70
+
71
+ self.assertEqual(prefix_product((3,2)), (1,3))
72
+
73
+ self.assertEqual(prefix_product((3,2,4)), (1,3,6))
74
+
75
+ self.assertEqual(prefix_product(((2,3),4)), ((1,2),6))
76
+
77
+ self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))),
78
+ ((1,2),(6,12,12),(24,120,240)))
79
+
80
+
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.left_inverse
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ from pycute import *
41
+
42
+ _LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ class TestLeftInverse(unittest.TestCase):
46
+ def helper_test_left_inverse(self, layout):
47
+ inv_layout = left_inverse(layout)
48
+
49
+ _LOGGER.debug(f"{layout} => {inv_layout}")
50
+
51
+ for i in range(size(layout)):
52
+ self.assertEqual(inv_layout(layout(i)), i)
53
+
54
+ def test_left_inverse(self):
55
+ test = Layout(1,0)
56
+ self.helper_test_left_inverse(test)
57
+
58
+ test = Layout((1,1),(0,0))
59
+ self.helper_test_left_inverse(test)
60
+
61
+ test = Layout(1,1)
62
+ self.helper_test_left_inverse(test)
63
+
64
+ test = Layout(4,1)
65
+ self.helper_test_left_inverse(test)
66
+
67
+ test = Layout(4,2)
68
+ self.helper_test_left_inverse(test)
69
+
70
+ test = Layout((8,4),(1,8))
71
+ self.helper_test_left_inverse(test)
72
+
73
+ test = Layout((8,4),(4,1))
74
+ self.helper_test_left_inverse(test)
75
+
76
+ test = Layout((2,4,6),(1,2,8))
77
+ self.helper_test_left_inverse(test)
78
+
79
+ test = Layout((2,4,6),(4,1,8))
80
+ self.helper_test_left_inverse(test)
81
+
82
+ test = Layout((4,2),(1,16))
83
+ self.helper_test_left_inverse(test)
84
+
85
+
86
+ if __name__ == "__main__":
87
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.left_inverse
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+
40
+ from pycute import *
41
+
42
+ _LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ class TestRightInverse(unittest.TestCase):
46
+ def helper_test_right_inverse(self, layout):
47
+ inv_layout = right_inverse(layout)
48
+
49
+ _LOGGER.debug(f"{layout} => {inv_layout}")
50
+
51
+ for i in range(size(inv_layout)):
52
+ self.assertEqual(layout(inv_layout(i)), i)
53
+
54
+ def test_right_inverse(self):
55
+ test = Layout(1,0)
56
+ self.helper_test_right_inverse(test)
57
+
58
+ test = Layout((1,1),(0,0))
59
+ self.helper_test_right_inverse(test)
60
+
61
+ test = Layout((3,7),(0,0))
62
+ self.helper_test_right_inverse(test)
63
+
64
+ test = Layout(1,1)
65
+ self.helper_test_right_inverse(test)
66
+
67
+ test = Layout(4,0)
68
+ self.helper_test_right_inverse(test)
69
+
70
+ test = Layout(4,1)
71
+ self.helper_test_right_inverse(test)
72
+
73
+ test = Layout(4,2)
74
+ self.helper_test_right_inverse(test)
75
+
76
+ test = Layout((2,4),(0,2))
77
+ self.helper_test_right_inverse(test)
78
+
79
+ test = Layout((8,4),(1,8))
80
+ self.helper_test_right_inverse(test)
81
+
82
+ test = Layout((8,4),(4,1))
83
+ self.helper_test_right_inverse(test)
84
+
85
+ test = Layout((2,4,6),(1,2,8))
86
+ self.helper_test_right_inverse(test)
87
+
88
+ test = Layout((2,4,6),(4,1,8))
89
+ self.helper_test_right_inverse(test)
90
+
91
+ test = Layout((4,2),(1,16))
92
+ self.helper_test_right_inverse(test)
93
+
94
+
95
+ if __name__ == "__main__":
96
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #################################################################################################
2
+ #
3
+ # Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ # SPDX-License-Identifier: BSD-3-Clause
5
+ #
6
+ # Redistribution and use in source and binary forms, with or without
7
+ # modification, are permitted provided that the following conditions are met:
8
+ #
9
+ # 1. Redistributions of source code must retain the above copyright notice, this
10
+ # list of conditions and the following disclaimer.
11
+ #
12
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ # this list of conditions and the following disclaimer in the documentation
14
+ # and/or other materials provided with the distribution.
15
+ #
16
+ # 3. Neither the name of the copyright holder nor the names of its
17
+ # contributors may be used to endorse or promote products derived from
18
+ # this software without specific prior written permission.
19
+ #
20
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30
+ #
31
+ #################################################################################################
32
+
33
+ """
34
+ Unit tests for pycute.typing
35
+ """
36
+
37
+ import logging
38
+ import unittest
39
+ from pycute import *
40
+
41
+ _LOGGER = logging.getLogger(__name__)
42
+
43
+
44
+ class TestTyping(unittest.TestCase):
45
+ def helper_test_typing(self, _cls, _obj, cls, expected: bool):
46
+ _LOGGER.debug(f"issubclass({_cls}, {cls})")
47
+ _LOGGER.debug(f"isinstance({_obj}, {cls})")
48
+
49
+ self.assertEqual(expected, issubclass(_cls, cls))
50
+ self.assertEqual(expected, isinstance(_obj, cls))
51
+
52
+ def test_typing(self):
53
+ self.helper_test_typing(int, 1, Integer, True)
54
+ self.helper_test_typing(float, 1., Integer, False)
55
+ self.helper_test_typing(str, 'hi', Integer, False)
56
+ self.helper_test_typing(bool, False, Integer, False)
57
+
58
+ if __name__ == '__main__':
59
+ unittest.main()
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+ #pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */
34
+
35
+ #pragma nv_diag_suppress boolean_controlling_expr_is_constant
36
+ #include <gtest/gtest.h>
37
+ #pragma nv_diag_warning boolean_controlling_expr_is_constant
38
+ #pragma warning( disable : 4503)
39
+
40
+ #include <cstdlib>
41
+ #include <string>
42
+
43
+ #include <cuda_runtime_api.h>
44
+
45
+ /////////////////////////////////////////////////////////////////////////////////////////////////
46
+
47
+ /// Gets a CUDA device
48
+ cudaDeviceProp GetCudaDevice();
49
+
50
+ /// Prints device properties
51
+ std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device);
52
+
53
+ /////////////////////////////////////////////////////////////////////////////////////////////////
54
+
55
+ /// Sets flags for Unit test
56
+ void FilterArchitecture();
57
+
58
+ /////////////////////////////////////////////////////////////////////////////////////////////////
59
+
60
+ /// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order
61
+ // of problem sizes run by CUTLASS unit tests
62
+ int CutlassUnitTestProblemCount();
63
+
64
+ /////////////////////////////////////////////////////////////////////////////////////////////////
65
+
66
+ // active test macro
67
+ #define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
68
+ TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__
69
+
70
+ // disabled test macro
71
+ #define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
72
+ TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {}
73
+
74
+ #if CUTLASS_TEST_LEVEL == 0
75
+ #define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
76
+ #define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
77
+ #define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
78
+ #elif CUTLASS_TEST_LEVEL == 1
79
+ #define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
80
+ #define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
81
+ #define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
82
+ #else
83
+ #define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
84
+ #define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
85
+ #define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__)
86
+ #endif
87
+
88
+ #if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS)
89
+ #define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false
90
+ #endif
91
+
92
+ #if (__CUDACC_VER_MAJOR__ >= 12)
93
+ #define CUDA_12_0_SM90_FEATURES_SUPPORTED true
94
+ #else
95
+ #define CUDA_12_0_SM90_FEATURES_SUPPORTED false
96
+ #endif
97
+
98
+ #include <cutlass/cutlass.h>
99
+ #include <cutlass/numeric_types.h>
100
+ #include <cutlass/trace.h>
101
+
102
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h ADDED
@@ -0,0 +1,907 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Helper to construct cached name for
33
+ */
34
+ #pragma once
35
+
36
+ #include <typeinfo>
37
+ #include <fstream>
38
+ #include <list>
39
+ #include <utility>
40
+ #include <sstream>
41
+
42
+ #include "cutlass/cutlass.h"
43
+ #include "cutlass/layout/matrix.h"
44
+ #include "cutlass/conv/convolution.h"
45
+ #include "cutlass/conv/conv2d_problem_size.h"
46
+
47
+ #include "cutlass/conv/conv3d_problem_size.h"
48
+ #include "cutlass/core_io.h"
49
+ #include "cutlass/util/tensor_view_io.h"
50
+
51
+ #include "thrust/universal_vector.h"
52
+
53
+ #ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS
54
+ #define CUTLASS_TEST_ENABLE_CACHED_RESULTS false
55
+ #endif
56
+
57
+ /////////////////////////////////////////////////////////////////////////////////////////////////
58
+
59
+ namespace test::conv::device {
60
+
61
+ /////////////////////////////////////////////////////////////////////////////////////////////////
62
+
63
+ /// Result of a test
64
+ struct CachedTestKey {
65
+
66
+ std::string op; ///< Concatenated string representation of operation performed
67
+ std::string problem; ///< Concatenated string representation of problem description
68
+ std::string types; ///< Concatenated string representation of operand types
69
+ uint32_t A; ///< Hashed result of tensor A
70
+ uint32_t B; ///< Hashed result of tensor B
71
+ uint32_t C; ///< Hashed result of tensor C
72
+
73
+ //
74
+ // Methods
75
+ //
76
+ inline CachedTestKey(): A(), B(), C() { }
77
+
78
+ inline CachedTestKey(
79
+ std::string op, ///< Concatenated string representation of operation performed
80
+ std::string problem, ///< Concatenated string representation of problem description
81
+ std::string types, ///< Concatenated string representation of operand types
82
+ uint32_t A, ///< Hashed result of tensor A
83
+ uint32_t B, ///< Hashed result of tensor B
84
+ uint32_t C ///< Hashed result of tensor C
85
+ ):
86
+ op(op), problem(problem), types(types), A(A), B(B), C(C)
87
+ { }
88
+
89
+ /// Checks for equality of the problem
90
+ bool operator==(CachedTestKey const &rhs) const {
91
+ return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C;
92
+ }
93
+ };
94
+
95
+ /////////////////////////////////////////////////////////////////////////////////////////////////
96
+
97
+ inline std::istream &operator>>(std::istream &in, CachedTestKey &result) {
98
+
99
+ in >> result.op;
100
+ in >> result.problem;
101
+ in >> result.types;
102
+ in >> result.A;
103
+ in >> result.B;
104
+ in >> result.C;
105
+
106
+ return in;
107
+ }
108
+
109
+ inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) {
110
+
111
+ out << result.op << " ";
112
+ out << result.problem << " ";
113
+ out << result.types << " ";
114
+ out << result.A << " ";
115
+ out << result.B << " ";
116
+ out << result.C << " ";
117
+
118
+ return out;
119
+ }
120
+
121
+ /////////////////////////////////////////////////////////////////////////////////////////////////
122
+
123
+ struct CachedTestResult {
124
+ uint32_t D;
125
+ //
126
+ // Methods
127
+ //
128
+
129
+ CachedTestResult(): D()
130
+ { }
131
+
132
+ CachedTestResult(uint32_t D): D(D)
133
+ { }
134
+
135
+ operator bool() const {
136
+ return bool(D);
137
+ }
138
+ };
139
+
140
+ /////////////////////////////////////////////////////////////////////////////////////////////////
141
+
142
+ inline std::istream &operator>>(std::istream &in, CachedTestResult &result) {
143
+ in >> result.D;
144
+ return in;
145
+ }
146
+
147
+ inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) {
148
+ out << result.D;
149
+ return out;
150
+ }
151
+
152
+ /////////////////////////////////////////////////////////////////////////////////////////////////
153
+
154
+ struct CachedTestResultListing {
155
+
156
+ std::list<std::pair<CachedTestKey, CachedTestResult>> results;
157
+
158
+ //
159
+ // Methods
160
+ //
161
+
162
+ inline CachedTestResultListing(std::string const &path) {
163
+ std::ifstream file(path);
164
+
165
+ while (file.good()) {
166
+ CachedTestKey key;
167
+ file >> key;
168
+
169
+ CachedTestResult result;
170
+ file >> result;
171
+
172
+ if (result) {
173
+ results.push_back(std::make_pair(key, result));
174
+ }
175
+ }
176
+ }
177
+
178
+ /// Returns the cached result
179
+ std::pair<bool, CachedTestResult> find(CachedTestKey const &rhs) const {
180
+ for (auto const & result : results) {
181
+ if (result.first == rhs) {
182
+ return std::make_pair(true, result.second);
183
+ }
184
+ }
185
+ return std::make_pair(false, CachedTestResult());
186
+ }
187
+
188
+ /// Appends an entry
189
+ void append(CachedTestKey const &key, CachedTestResult const &result) {
190
+ if (result) {
191
+ results.push_back(std::make_pair(key, result));
192
+ }
193
+ }
194
+
195
+ /// Writes the entire listing to a file
196
+ bool write(std::string const &path) {
197
+ std::ofstream file(path);
198
+ if (!file.good()) {
199
+ return false;
200
+ }
201
+
202
+ for (auto const &result : results) {
203
+ file << result.first << result.second << std::endl;
204
+ }
205
+
206
+ return true;
207
+ }
208
+ };
209
+
210
+ /////////////////////////////////////////////////////////////////////////////////////////////////
211
+
212
+ template <typename Element>
213
+ struct ScalarEncoder {
214
+ Element scalar;
215
+
216
+ ScalarEncoder(Element s): scalar(s) { }
217
+
218
+ std::string str() const {
219
+ std::stringstream ss;
220
+ Element s = scalar;
221
+ if (s < Element()) {
222
+ s = -s;
223
+ ss << "n";
224
+ }
225
+ ss << s;
226
+ return ss.str();
227
+ }
228
+ };
229
+
230
+ template <typename Element>
231
+ ScalarEncoder<Element> EncodeScalar(Element a) {
232
+ return ScalarEncoder<Element>(a);
233
+ }
234
+
235
+ template <typename Element>
236
+ struct ScalarEncoder<cutlass::complex<Element>> {
237
+ cutlass::complex<Element> scalar;
238
+
239
+ ScalarEncoder(cutlass::complex<Element> s): scalar(s) { }
240
+
241
+ std::string str() const {
242
+ std::stringstream ss;
243
+ ss << EncodeScalar<Element>(scalar.real()) << "_" << EncodeScalar<Element>(scalar.imag()) << "i";
244
+ return ss.str();
245
+ }
246
+ };
247
+
248
+ template <typename Element>
249
+ std::ostream &operator<<(std::ostream &out, ScalarEncoder<Element> const &scalar) {
250
+ out << scalar.str();
251
+ return out;
252
+ }
253
+
254
+ /////////////////////////////////////////////////////////////////////////////////////////////////
255
+
256
+ inline char const *EncodeOperator(cutlass::conv::Operator conv_op) {
257
+ switch (conv_op) {
258
+ case cutlass::conv::Operator::kFprop: return "fprop";
259
+ case cutlass::conv::Operator::kDgrad: return "dgrad";
260
+ case cutlass::conv::Operator::kWgrad: return "wgrad";
261
+ case cutlass::conv::Operator::kDeconv: return "deconv";
262
+ }
263
+ return "conv_unknown";
264
+ }
265
+
266
+ /////////////////////////////////////////////////////////////////////////////////////////////////
267
+
268
+ // Encode GemmCoord (Gemm problem size)
269
+ inline std::ostream &EncodeProblemSize(
270
+ std::ostream &out,
271
+ cutlass::gemm::GemmCoord const &problem) {
272
+
273
+ out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_";
274
+
275
+ return out;
276
+ }
277
+
278
+ /////////////////////////////////////////////////////////////////////////////////////////////////
279
+ // Encode Conv2dProblemSize
280
+ inline std::ostream &EncodeProblemSize(
281
+ std::ostream &out,
282
+ cutlass::conv::Conv2dProblemSize const &problem) {
283
+
284
+ out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_"
285
+ << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_";
286
+
287
+ out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_";
288
+ out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_";
289
+ out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_";
290
+
291
+ switch (problem.mode) {
292
+ case cutlass::conv::Mode::kCrossCorrelation:
293
+ out << "corr";
294
+ break;
295
+ case cutlass::conv::Mode::kConvolution:
296
+ out << "conv";
297
+ break;
298
+ }
299
+
300
+ return out;
301
+ }
302
+
303
+ /////////////////////////////////////////////////////////////////////////////////////////////////
304
+
305
+ // Encode Conv3dProblemSize
306
+ inline std::ostream &EncodeProblemSize(
307
+ std::ostream &out,
308
+ cutlass::conv::Conv3dProblemSize const &problem) {
309
+
310
+ out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_"
311
+ << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_";
312
+
313
+ out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_";
314
+ out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_";
315
+ out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_";
316
+
317
+ switch (problem.mode) {
318
+ case cutlass::conv::Mode::kCrossCorrelation:
319
+ out << "corr";
320
+ break;
321
+ case cutlass::conv::Mode::kConvolution:
322
+ out << "conv";
323
+ break;
324
+ }
325
+
326
+ return out;
327
+ }
328
+
329
+ /////////////////////////////////////////////////////////////////////////////////////////////////
330
+ // Encode 3.x ConvNd ProblemShape
331
+ template <class ProblemShape>
332
+ inline std::ostream &EncodeProblemSize(
333
+ std::ostream &out,
334
+ ProblemShape const& problem_shape) {
335
+
336
+ out << problem_shape.shape_A << "_";
337
+ out << problem_shape.shape_B << "_";
338
+
339
+ out << "padl" << problem_shape.lower_padding << "_";
340
+ out << "padu" << problem_shape.upper_padding << "_";
341
+ out << "str" << problem_shape.traversal_stride << "_";
342
+ out << "dil" << problem_shape.dilation << "_";
343
+
344
+ switch (problem_shape.mode) {
345
+ case cutlass::conv::Mode::kCrossCorrelation:
346
+ out << "corr";
347
+ break;
348
+ case cutlass::conv::Mode::kConvolution:
349
+ out << "conv";
350
+ break;
351
+ }
352
+
353
+ return out;
354
+ }
355
+
356
+ /////////////////////////////////////////////////////////////////////////////////////////////////
357
+
358
+ template <typename Element>
359
+ inline std::string ElementTypeName() {
360
+ return std::string(typeid(Element).name());
361
+ }
362
+
363
+ template <>
364
+ inline std::string ElementTypeName<cutlass::half_t>() {
365
+ return "h";
366
+ }
367
+
368
+ template <>
369
+ inline std::string ElementTypeName<cutlass::complex<cutlass::half_t>>() {
370
+ return "ch";
371
+ }
372
+
373
+ template <>
374
+ inline std::string ElementTypeName<cutlass::bfloat16_t>() {
375
+ return "bf16";
376
+ }
377
+
378
+ template <>
379
+ inline std::string ElementTypeName<cutlass::complex<cutlass::bfloat16_t>>() {
380
+ return "cbf16";
381
+ }
382
+
383
+ template <>
384
+ inline std::string ElementTypeName<cutlass::tfloat32_t>() {
385
+ return "tf32";
386
+ }
387
+
388
+ template <>
389
+ inline std::string ElementTypeName<cutlass::complex<cutlass::tfloat32_t>>() {
390
+ return "ctf32";
391
+ }
392
+
393
+ template <>
394
+ inline std::string ElementTypeName<cutlass::complex<float>>() {
395
+ return "c";
396
+ }
397
+
398
+ template <>
399
+ inline std::string ElementTypeName<cutlass::complex<double>>() {
400
+ return "z";
401
+ }
402
+
403
+ template <>
404
+ inline std::string ElementTypeName<cutlass::Quaternion<float>>() {
405
+ return "q";
406
+ }
407
+
408
+ template <>
409
+ inline std::string ElementTypeName<int8_t>() {
410
+ return "s8";
411
+ }
412
+
413
+ template <>
414
+ inline std::string ElementTypeName<uint8_t>() {
415
+ return "u8";
416
+ }
417
+
418
+ template <>
419
+ inline std::string ElementTypeName<cutlass::int4b_t>() {
420
+ return "s4";
421
+ }
422
+
423
+ template <>
424
+ inline std::string ElementTypeName<cutlass::uint4b_t>() {
425
+ return "u4";
426
+ }
427
+
428
+ /////////////////////////////////////////////////////////////////////////////////////////////////
429
+
430
+ template <typename Layout>
431
+ inline std::string LayoutTypeName() {
432
+ return std::string(typeid(Layout).name());
433
+ }
434
+
435
+ template <>
436
+ inline std::string LayoutTypeName<cutlass::layout::ColumnMajor>() {
437
+ return "n";
438
+ }
439
+
440
+ template <>
441
+ inline std::string LayoutTypeName<cutlass::layout::RowMajor>() {
442
+ return "t";
443
+ }
444
+
445
+ template <>
446
+ inline std::string LayoutTypeName<cutlass::layout::TensorNHWC>() {
447
+ return "nhwc";
448
+ }
449
+
450
+ template <>
451
+ inline std::string LayoutTypeName<cutlass::layout::TensorNCxHWx<32>>() {
452
+ return "nc32hw32";
453
+ }
454
+
455
+ template <>
456
+ inline std::string LayoutTypeName<cutlass::layout::TensorNCxHWx<64>>() {
457
+ return "nc64hw64";
458
+ }
459
+
460
+ template <>
461
+ inline std::string LayoutTypeName<cutlass::layout::TensorCxRSKx<32>>() {
462
+ return "c32rsk32";
463
+ }
464
+
465
+ template <>
466
+ inline std::string LayoutTypeName<cutlass::layout::TensorCxRSKx<64>>() {
467
+ return "c64rsk64";
468
+ }
469
+
470
+ template <>
471
+ inline std::string LayoutTypeName<cutlass::layout::TensorNDHWC>() {
472
+ return "ndhwc";
473
+ }
474
+
475
+ /////////////////////////////////////////////////////////////////////////////////////////////////
476
+
477
+ template <typename Element, typename Layout>
478
+ inline std::string TensorTypeName() {
479
+ std::stringstream ss;
480
+ ss << ElementTypeName<Element>() << LayoutTypeName<Layout>();
481
+ return ss.str();
482
+ }
483
+
484
+ template <typename Element>
485
+ inline std::string TensorTypeName() {
486
+ std::stringstream ss;
487
+ ss << ElementTypeName<Element>();
488
+ return ss.str();
489
+ }
490
+ /////////////////////////////////////////////////////////////////////////////////////////////////
491
+
492
+ /// Hash function on a byte array
493
+ struct CRC32 {
494
+
495
+ uint32_t table[256];
496
+
497
+ //
498
+ // Methods
499
+ //
500
+
501
+ CRC32() {
502
+
503
+ uint32_t rem;
504
+ int i, j;
505
+
506
+ for (i = 0; i < 256; i++) {
507
+ rem = i;
508
+ for (j = 0; j < 8; j++) {
509
+ if (rem & 1) {
510
+ rem >>= 1;
511
+ rem ^= 0xedb88320;
512
+ } else
513
+ rem >>= 1;
514
+ }
515
+ table[i] = rem;
516
+ }
517
+ }
518
+
519
+ /// Computes the CRC of an array of bytes
520
+ uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const {
521
+ uint8_t const *p = static_cast<uint8_t const *>(start);
522
+ uint8_t const *q = static_cast<uint8_t const *>(start) + length;
523
+
524
+ crc = ~crc;
525
+
526
+ for (; p != q; ++p) {
527
+ uint8_t octet = *p;
528
+ crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet];
529
+ }
530
+
531
+ return ~crc;
532
+ }
533
+ };
534
+
535
+ /////////////////////////////////////////////////////////////////////////////////////////////////
536
+
537
+ template <
538
+ typename Element, typename Layout
539
+ >
540
+ uint32_t TensorHash(
541
+ cutlass::TensorView<Element, Layout> view,
542
+ CRC32 const &hash = CRC32(),
543
+ uint32_t crc = uint32_t()
544
+ ) {
545
+
546
+ return hash(view.data(), view.capacity() * cutlass::sizeof_bits<Element>::value / 8, crc);
547
+ }
548
+
549
+ template <typename Element>
550
+ uint32_t TensorHash(
551
+ thrust::universal_vector<Element>& tensor,
552
+ CRC32 const &hash = CRC32(),
553
+ uint32_t crc = uint32_t()
554
+ ) {
555
+
556
+ return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits<Element>::value / 8, crc);
557
+ }
558
+
559
+ /////////////////////////////////////////////////////////////////////////////////////////////////
560
+
561
+ template <
562
+ typename ElementA, typename LayoutA,
563
+ typename ElementB, typename LayoutB,
564
+ typename ElementC, typename LayoutC,
565
+ typename ElementAccumulator,
566
+ typename ElementCompute
567
+ >
568
+ inline std::ostream &EncodeTypes(
569
+ std::ostream &out
570
+ ) {
571
+
572
+ out << TensorTypeName<ElementA, LayoutA>() << "_"
573
+ << TensorTypeName<ElementB, LayoutB>() << "_"
574
+ << TensorTypeName<ElementC, LayoutC>() << "_"
575
+ << ElementTypeName<ElementAccumulator>() << "_"
576
+ << ElementTypeName<ElementCompute>();
577
+
578
+ return out;
579
+ }
580
+
581
+ template <
582
+ typename ElementA,
583
+ typename ElementB,
584
+ typename ElementC,
585
+ typename ElementD
586
+ >
587
+ inline std::ostream &EncodeTypes(
588
+ std::ostream &out
589
+ ) {
590
+
591
+ out << TensorTypeName<ElementA>() << "_"
592
+ << TensorTypeName<ElementB>() << "_"
593
+ << TensorTypeName<ElementC>() << "_"
594
+ << ElementTypeName<ElementD>();
595
+
596
+ return out;
597
+ }
598
+ /////////////////////////////////////////////////////////////////////////////////////////////////
599
+
600
+ template <
601
+ typename ElementA, typename LayoutA,
602
+ typename ElementB, typename LayoutB,
603
+ typename ElementC, typename LayoutC,
604
+ typename ElementAccumulator,
605
+ typename ElementCompute
606
+ >
607
+ inline CachedTestKey CreateCachedGemmTestKey(
608
+ cutlass::gemm::GemmCoord const &problem,
609
+ ElementCompute alpha,
610
+ ElementCompute beta,
611
+ cutlass::TensorView<ElementA, LayoutA> A,
612
+ cutlass::TensorView<ElementB, LayoutB> B,
613
+ cutlass::TensorView<ElementC, LayoutC> C
614
+ ) {
615
+
616
+ CachedTestKey key;
617
+
618
+ // Encode gemm operator and problem sizes
619
+ key.op = "gemm";
620
+
621
+ std::stringstream ss_problem;
622
+ EncodeProblemSize(ss_problem, problem);
623
+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
624
+ key.problem = ss_problem.str();
625
+
626
+ // Encode problem data types
627
+ std::stringstream ss_types;
628
+ EncodeTypes<
629
+ ElementA, LayoutA,
630
+ ElementB, LayoutB,
631
+ ElementC, LayoutC,
632
+ ElementAccumulator,
633
+ ElementCompute>(ss_types);
634
+ key.types = ss_types.str();
635
+
636
+ // Encode hash for problem data
637
+ CRC32 crc_hash;
638
+ key.A = TensorHash(A, crc_hash);
639
+ key.B = TensorHash(B, crc_hash);
640
+ key.C = TensorHash(C, crc_hash);
641
+
642
+ return key;
643
+ }
644
+
645
+ /////////////////////////////////////////////////////////////////////////////////////////////////
646
+
647
+
648
+ template <
649
+ typename ElementA, typename LayoutA,
650
+ typename ElementB, typename LayoutB,
651
+ typename ElementC, typename LayoutC,
652
+ typename ElementAccumulator,
653
+ typename ElementCompute
654
+ >
655
+ inline CachedTestKey CreateCachedConv2dTestKey(
656
+
657
+ cutlass::conv::Operator conv_operator,
658
+ cutlass::conv::Conv2dProblemSize const &problem,
659
+ ElementCompute alpha,
660
+ ElementCompute beta,
661
+ cutlass::TensorView<ElementA, LayoutA> A,
662
+ cutlass::TensorView<ElementB, LayoutB> B,
663
+ cutlass::TensorView<ElementC, LayoutC> C
664
+ ) {
665
+
666
+ CachedTestKey key;
667
+
668
+ // Encode conv2d operator and problem sizes
669
+ key.op = "conv2d";
670
+
671
+ std::stringstream ss_problem;
672
+ ss_problem << EncodeOperator(conv_operator) << "_";
673
+ EncodeProblemSize(ss_problem, problem);
674
+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
675
+
676
+ key.problem = ss_problem.str();
677
+
678
+ // Encode problem data types
679
+ std::stringstream ss_types;
680
+ EncodeTypes<
681
+ ElementA, LayoutA,
682
+ ElementB, LayoutB,
683
+ ElementC, LayoutC,
684
+ ElementAccumulator,
685
+ ElementCompute>(ss_types);
686
+ key.types = ss_types.str();
687
+
688
+ // Encode hash for problem data
689
+ CRC32 crc_hash;
690
+
691
+ key.A = TensorHash(A, crc_hash);
692
+ key.B = TensorHash(B, crc_hash);
693
+ key.C = TensorHash(C, crc_hash);
694
+
695
+ return key;
696
+ }
697
+
698
+ /////////////////////////////////////////////////////////////////////////////////////////////////
699
+
700
+ template <
701
+ typename ElementA, typename LayoutA,
702
+ typename ElementB, typename LayoutB,
703
+ typename ElementC, typename LayoutC,
704
+ typename ElementAccumulator,
705
+ typename ElementCompute
706
+ >
707
+ inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey(
708
+
709
+ cutlass::conv::Operator conv_operator,
710
+ cutlass::conv::Conv2dProblemSize const &problem,
711
+ ElementCompute alpha,
712
+ ElementCompute beta,
713
+ cutlass::TensorView<ElementA, LayoutA> A,
714
+ cutlass::TensorView<ElementB, LayoutB> B,
715
+ cutlass::TensorView<ElementC, LayoutC> C
716
+ ) {
717
+
718
+ CachedTestKey key;
719
+
720
+ // Encode conv2d operator and problem sizes
721
+ key.op = "conv2d_with_broadcast";
722
+
723
+ std::stringstream ss_problem;
724
+ ss_problem << EncodeOperator(conv_operator) << "_";
725
+ EncodeProblemSize(ss_problem, problem);
726
+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
727
+
728
+ key.problem = ss_problem.str();
729
+
730
+ // Encode problem data types
731
+ std::stringstream ss_types;
732
+ EncodeTypes<
733
+ ElementA, LayoutA,
734
+ ElementB, LayoutB,
735
+ ElementC, LayoutC,
736
+ ElementAccumulator,
737
+ ElementCompute>(ss_types);
738
+ key.types = ss_types.str();
739
+
740
+ // Encode hash for problem data
741
+ CRC32 crc_hash;
742
+
743
+ key.A = TensorHash(A, crc_hash);
744
+ key.B = TensorHash(B, crc_hash);
745
+ key.C = TensorHash(C, crc_hash);
746
+
747
+ return key;
748
+ }
749
+
750
+ /////////////////////////////////////////////////////////////////////////////////////////////////
751
+
752
+ template <
753
+ typename ElementA, typename LayoutA,
754
+ typename ElementB, typename LayoutB,
755
+ typename ElementC, typename LayoutC,
756
+ typename ElementAccumulator,
757
+ typename ElementCompute
758
+ >
759
+ inline CachedTestKey CreateCachedConv2dWithReductionTestKey(
760
+
761
+ cutlass::conv::Operator conv_operator,
762
+ cutlass::conv::Conv2dProblemSize const &problem,
763
+ ElementCompute alpha,
764
+ ElementCompute beta,
765
+ cutlass::TensorView<ElementA, LayoutA> A,
766
+ cutlass::TensorView<ElementB, LayoutB> B,
767
+ cutlass::TensorView<ElementC, LayoutC> C
768
+ ) {
769
+
770
+ CachedTestKey key;
771
+
772
+ // Encode conv2d operator and problem sizes
773
+ key.op = "conv2d_with_reduction";
774
+
775
+ std::stringstream ss_problem;
776
+ ss_problem << EncodeOperator(conv_operator) << "_";
777
+ EncodeProblemSize(ss_problem, problem);
778
+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
779
+
780
+ key.problem = ss_problem.str();
781
+
782
+ // Encode problem data types
783
+ std::stringstream ss_types;
784
+ EncodeTypes<
785
+ ElementA, LayoutA,
786
+ ElementB, LayoutB,
787
+ ElementC, LayoutC,
788
+ ElementAccumulator,
789
+ ElementCompute>(ss_types);
790
+ key.types = ss_types.str();
791
+
792
+ // Encode hash for problem data
793
+ CRC32 crc_hash;
794
+
795
+ key.A = TensorHash(A, crc_hash);
796
+ key.B = TensorHash(B, crc_hash);
797
+ key.C = TensorHash(C, crc_hash);
798
+
799
+ return key;
800
+ }
801
+
802
+ /////////////////////////////////////////////////////////////////////////////////////////////////
803
+
804
+ template <
805
+ typename ElementA, typename LayoutA,
806
+ typename ElementB, typename LayoutB,
807
+ typename ElementC, typename LayoutC,
808
+ typename ElementAccumulator,
809
+ typename ElementCompute
810
+ >
811
+ inline CachedTestKey CreateCachedConv3dTestKey(
812
+ cutlass::conv::Operator conv_operator,
813
+ cutlass::conv::Conv3dProblemSize const &problem,
814
+ ElementCompute alpha,
815
+ ElementCompute beta,
816
+ cutlass::TensorView<ElementA, LayoutA> A,
817
+ cutlass::TensorView<ElementB, LayoutB> B,
818
+ cutlass::TensorView<ElementC, LayoutC> C
819
+ ) {
820
+
821
+ CachedTestKey key;
822
+
823
+ // Encode conv3d operator and problem sizes
824
+ key.op = "conv3d";
825
+
826
+ std::stringstream ss_problem;
827
+
828
+ ss_problem << EncodeOperator(conv_operator) << "_";
829
+ EncodeProblemSize(ss_problem, problem);
830
+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
831
+
832
+ key.problem = ss_problem.str();
833
+
834
+ // Encode problem data types
835
+ std::stringstream ss_types;
836
+ EncodeTypes<
837
+ ElementA, LayoutA,
838
+ ElementB, LayoutB,
839
+ ElementC, LayoutC,
840
+ ElementAccumulator,
841
+ ElementCompute>(ss_types);
842
+ key.types = ss_types.str();
843
+
844
+ // Encode problem data
845
+ CRC32 crc_hash;
846
+ key.A = TensorHash(A, crc_hash);
847
+ key.B = TensorHash(B, crc_hash);
848
+ key.C = TensorHash(C, crc_hash);
849
+
850
+ return key;
851
+ }
852
+
853
+ /////////////////////////////////////////////////////////////////////////////////////////////////
854
+
855
+ template <
856
+ class ProblemShape,
857
+ typename ElementA,
858
+ typename ElementB,
859
+ typename ElementC,
860
+ typename ElementD
861
+ >
862
+ inline CachedTestKey CreateCachedConvNd3xTestKey(
863
+ cutlass::conv::Operator conv_operator,
864
+ ProblemShape const& problem_shape,
865
+ double alpha,
866
+ double beta,
867
+ thrust::universal_vector<ElementA> A,
868
+ thrust::universal_vector<ElementB> B,
869
+ thrust::universal_vector<ElementC> C
870
+ ) {
871
+
872
+ CachedTestKey key;
873
+
874
+ // Encode convNd operator and problem sizes
875
+ std::stringstream ss_op;
876
+ ss_op << "conv" << ProblemShape::RankS << "d";
877
+ key.op = ss_op.str();
878
+
879
+ std::stringstream ss_problem;
880
+ ss_problem << EncodeOperator(conv_operator) << "_";
881
+ EncodeProblemSize(ss_problem, problem_shape);
882
+ ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta);
883
+ key.problem = ss_problem.str();
884
+
885
+ // Encode problem data types
886
+ std::stringstream ss_types;
887
+ EncodeTypes<
888
+ ElementA,
889
+ ElementB,
890
+ ElementC,
891
+ ElementD>(ss_types);
892
+ key.types = ss_types.str();
893
+
894
+ // Encode problem data
895
+ CRC32 crc_hash;
896
+ key.A = TensorHash(A, crc_hash);
897
+ key.B = TensorHash(B, crc_hash);
898
+ key.C = TensorHash(C, crc_hash);
899
+
900
+ return key;
901
+ }
902
+
903
+ /////////////////////////////////////////////////////////////////////////////////////////////////
904
+
905
+ } // namespace test::conv::device
906
+
907
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed sizes for Conv2d problem
33
+ */
34
+ #pragma once
35
+
36
+ #include <vector>
37
+
38
+ #include "cutlass/cutlass.h"
39
+ #include "cutlass/layout/matrix.h"
40
+ #include "cutlass/conv/convolution.h"
41
+ #include "cutlass/conv/conv2d_problem_size.h"
42
+
43
+ namespace test {
44
+ namespace conv {
45
+ namespace device {
46
+
47
+ using Conv2dProblemVector = std::vector<cutlass::conv::Conv2dProblemSize>;
48
+
49
+ //
50
+ // Structures to prune items from Conv2dProblemVector
51
+ //
52
+ // Specification template for pruning items for convolution problem lists
53
+ template <typename T> struct Specification
54
+ {
55
+ virtual ~Specification() = default;
56
+ virtual bool is_satisfied(T item) const = 0;
57
+ };
58
+
59
+ // input size (NHWC) specification
60
+ struct InputSizeSpecification : Specification<cutlass::conv::Conv2dProblemSize>
61
+ {
62
+ cutlass::Tensor4DCoord input_size;
63
+
64
+ InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {}
65
+
66
+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
67
+ return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C));
68
+ }
69
+ };
70
+
71
+ // stride (stride_h, stride_w) specification
72
+ struct StrideSpecification : Specification<cutlass::conv::Conv2dProblemSize>
73
+ {
74
+ cutlass::MatrixCoord stride;
75
+
76
+ StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {}
77
+
78
+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
79
+ return ((stride.row() == item.stride_h) && (stride.column() == item.stride_h));
80
+ }
81
+ };
82
+
83
+ // channel (C,K) specification, must be multiple of minimum channel
84
+ struct ChannelDivisibilitySpecification : Specification<cutlass::conv::Conv2dProblemSize>
85
+ {
86
+ int channel_multiple;
87
+
88
+ ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {}
89
+
90
+ bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
91
+ return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0));
92
+ }
93
+ };
94
+
95
+ //
96
+ // Pruning function for items from Conv2dProblemVector based on a Specification
97
+ //
98
+ inline Conv2dProblemVector prune(Conv2dProblemVector const &items,
99
+ Specification<cutlass::conv::Conv2dProblemSize> const &spec)
100
+ {
101
+ Conv2dProblemVector pruned_list;
102
+
103
+ for (auto& p : items)
104
+ if (spec.is_satisfied(p))
105
+ pruned_list.push_back(p);
106
+ return pruned_list;
107
+ }
108
+
109
+
110
+ ////////////////////////////////////////////////////////////////////////////
111
+ /// Structure TestbedConv2dProblemSizes initializes and holds conv default and
112
+ /// important network sizes
113
+ ////////////////////////////////////////////////////////////////////////////
114
+ struct TestbedConv2dProblemSizes {
115
+
116
+ //
117
+ // Data members
118
+ //
119
+ int minimum_channel_size;
120
+
121
+ Conv2dProblemVector conv2d_default_sizes;
122
+ Conv2dProblemVector conv2d_rigorous_sizes;
123
+ Conv2dProblemVector conv2d_resnet50_sizes;
124
+ Conv2dProblemVector conv2d_resnet50_sizes_perf;
125
+
126
+ //
127
+ // Methods
128
+ //
129
+ /// Default ctor
130
+ TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) {
131
+ initialize_conv2d_default_sizes();
132
+ initialize_conv2d_rigorous_sizes();
133
+ initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/);
134
+
135
+ initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/);
136
+ filter_all();
137
+ }
138
+
139
+ /// Eliminates some illegal cases
140
+ void filter_all() {
141
+
142
+ Conv2dProblemVector *problems_vectors[] = {
143
+ &conv2d_default_sizes,
144
+ &conv2d_rigorous_sizes,
145
+ &conv2d_resnet50_sizes,
146
+ &conv2d_resnet50_sizes_perf
147
+ };
148
+
149
+ for (Conv2dProblemVector *problems : problems_vectors) {
150
+ Conv2dProblemVector filtered;
151
+
152
+ for (cutlass::conv::Conv2dProblemSize const & problem : *problems) {
153
+ if (!(problem.C % minimum_channel_size)) {
154
+ filtered.push_back(problem);
155
+ }
156
+ }
157
+
158
+ *problems = filtered;
159
+ }
160
+ }
161
+
162
+ // Add a few standard convolution problem sizes
163
+ void initialize_conv2d_default_sizes() {
164
+
165
+ ////////////////////////////////////////////////////////////////////////////////////////////
166
+ // Small input size x stride (1,1)
167
+ // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
168
+ ////////////////////////////////////////////////////////////////////////////////////////////
169
+
170
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
171
+ {1, 1, 1, minimum_channel_size}, // input size (NHWC)
172
+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC)
173
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
174
+ {1, 1}, // stride (stride_h, stride_w)
175
+ {1, 1} // dilation (dilation_h, dilation_w)
176
+ ));
177
+
178
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
179
+ {1, 1, 8, minimum_channel_size}, // input size (NHWC)
180
+ {8, 1, 3, minimum_channel_size}, // filter size (KRSC)
181
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
182
+ {1, 1}, // stride (stride_h, stride_w)
183
+ {1, 1} // dilation (dilation_h, dilation_w)
184
+ ));
185
+
186
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
187
+ {1, 7, 8, minimum_channel_size}, // input size (NHWC)
188
+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC)
189
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
190
+ {1, 1}, // stride (stride_h, stride_w)
191
+ {1, 1} // dilation (dilation_h, dilation_w)
192
+ ));
193
+
194
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
195
+ {1, 7, 9, minimum_channel_size}, // input size (NHWC)
196
+ {8, 4, 4, minimum_channel_size}, // filter size (KRSC)
197
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
198
+ {1, 1}, // stride (stride_h, stride_w)
199
+ {1, 1} // dilation (dilation_h, dilation_w)
200
+ ));
201
+
202
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
203
+ {2, 7, 9, minimum_channel_size}, // input size (NHWC)
204
+ {8, 5, 5, minimum_channel_size}, // filter size (KRSC)
205
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
206
+ {1, 1}, // stride (stride_h, stride_w)
207
+ {1, 1} // dilation (dilation_h, dilation_w)
208
+ ));
209
+
210
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
211
+ {3, 7, 9, minimum_channel_size}, // input size (NHWC)
212
+ {8, 6, 5, minimum_channel_size}, // filter size (KRSC)
213
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
214
+ {1, 1}, // stride (stride_h, stride_w)
215
+ {1, 1} // dilation (dilation_h, dilation_w)
216
+ ));
217
+
218
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
219
+ {3, 7, 9, minimum_channel_size}, // input size (NHWC)
220
+ {8, 6, 6, minimum_channel_size}, // filter size (KRSC)
221
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
222
+ {1, 1}, // stride (stride_h, stride_w)
223
+ {1, 1} // dilation (dilation_h, dilation_w)
224
+ ));
225
+
226
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
227
+ {3, 7, 9, minimum_channel_size}, // input size (NHWC)
228
+ {8, 7, 7, minimum_channel_size}, // filter size (KRSC)
229
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
230
+ {1, 1}, // stride (stride_h, stride_w)
231
+ {1, 1} // dilation (dilation_h, dilation_w)
232
+ ));
233
+
234
+ ////////////////////////////////////////////////////////////////////////////////////////////
235
+ // Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0)
236
+ // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
237
+ ////////////////////////////////////////////////////////////////////////////////////////////
238
+
239
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
240
+ {1, 1, 1, minimum_channel_size}, // input size (NHWC)
241
+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC)
242
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
243
+ {1, 1}, // stride (stride_h, stride_w)
244
+ {1, 1} // dilation (dilation_h, dilation_w)
245
+ ));
246
+
247
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
248
+ {1, 1, 8, minimum_channel_size}, // input size (NHWC)
249
+ {8, 1, 3, minimum_channel_size}, // filter size (KRSC)
250
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
251
+ {1, 1}, // stride (stride_h, stride_w)
252
+ {1, 1} // dilation (dilation_h, dilation_w)
253
+ ));
254
+
255
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
256
+ {1, 7, 8, minimum_channel_size}, // input size (NHWC)
257
+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC)
258
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
259
+ {1, 1}, // stride (stride_h, stride_w)
260
+ {1, 1} // dilation (dilation_h, dilation_w)
261
+ ));
262
+
263
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
264
+ {1, 7, 9, minimum_channel_size}, // input size (NHWC)
265
+ {8, 4, 4, minimum_channel_size}, // filter size (KRSC)
266
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
267
+ {1, 1}, // stride (stride_h, stride_w)
268
+ {1, 1} // dilation (dilation_h, dilation_w)
269
+ ));
270
+
271
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
272
+ {2, 7, 9, minimum_channel_size}, // input size (NHWC)
273
+ {8, 5, 5, minimum_channel_size}, // filter size (KRSC)
274
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
275
+ {1, 1}, // stride (stride_h, stride_w)
276
+ {1, 1} // dilation (dilation_h, dilation_w)
277
+ ));
278
+
279
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
280
+ {3, 7, 9, minimum_channel_size}, // input size (NHWC)
281
+ {8, 6, 5, minimum_channel_size}, // filter size (KRSC)
282
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
283
+ {1, 1}, // stride (stride_h, stride_w)
284
+ {1, 1} // dilation (dilation_h, dilation_w)
285
+ ));
286
+
287
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
288
+ {3, 7, 9, minimum_channel_size}, // input size (NHWC)
289
+ {8, 6, 6, minimum_channel_size}, // filter size (KRSC)
290
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
291
+ {1, 1}, // stride (stride_h, stride_w)
292
+ {1, 1} // dilation (dilation_h, dilation_w)
293
+ ));
294
+
295
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
296
+ {3, 7, 9, minimum_channel_size}, // input size (NHWC)
297
+ {8, 7, 7, minimum_channel_size}, // filter size (KRSC)
298
+ {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _)
299
+ {1, 1}, // stride (stride_h, stride_w)
300
+ {1, 1} // dilation (dilation_h, dilation_w)
301
+ ));
302
+
303
+ ////////////////////////////////////////////////////////////////////////////////////////////
304
+ // Small input size x stride (2,2)
305
+ // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
306
+ ////////////////////////////////////////////////////////////////////////////////////////////
307
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
308
+ {1, 11, 7, minimum_channel_size}, // input size (NHWC)
309
+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC)
310
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
311
+ {2, 2}, // stride (stride_h, stride_w)
312
+ {1, 1} // dilation (dilation_h, dilation_w)
313
+ ));
314
+
315
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
316
+ {1, 11, 7, minimum_channel_size}, // input size (NHWC)
317
+ {8, 3, 3, minimum_channel_size}, // filter size (KRSC)
318
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
319
+ {2, 2}, // stride (stride_h, stride_w)
320
+ {1, 1} // dilation (dilation_h, dilation_w)
321
+ ));
322
+
323
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
324
+ {1, 13, 11, minimum_channel_size}, // input size (NHWC)
325
+ {8, 1, 1, minimum_channel_size}, // filter size (KRSC)
326
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
327
+ {2, 2}, // stride (stride_h, stride_w)
328
+ {1, 1} // dilation (dilation_h, dilation_w)
329
+ ));
330
+
331
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
332
+ {1, 17, 19, minimum_channel_size}, // input size (NHWC)
333
+ {16, 2, 2, minimum_channel_size}, // filter size (KRSC)
334
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
335
+ {2, 2}, // stride (stride_h, stride_w)
336
+ {1, 1} // dilation (dilation_h, dilation_w)
337
+ ));
338
+
339
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
340
+ {1, 23, 5, minimum_channel_size}, // input size (NHWC)
341
+ {16, 3, 3, minimum_channel_size}, // filter size (KRSC)
342
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
343
+ {2, 2}, // stride (stride_h, stride_w)
344
+ {1, 1} // dilation (dilation_h, dilation_w)
345
+ ));
346
+
347
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
348
+ {1, 13, 17, 8}, // input size (NHWC)
349
+ {24, 3, 3, 8}, // filter size (KRSC)
350
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
351
+ {2, 2}, // stride (stride_h, stride_w)
352
+ {1, 1} // dilation (dilation_h, dilation_w)
353
+ ));
354
+
355
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
356
+ {1, 23, 21, 8}, // input size (NHWC)
357
+ {24, 3, 3, 8}, // filter size (KRSC)
358
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
359
+ {3, 3}, // stride (stride_h, stride_w)
360
+ {1, 1} // dilation (dilation_h, dilation_w)
361
+ ));
362
+
363
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
364
+ {1, 20, 24, 8}, // input size (NHWC)
365
+ {40, 3, 3, 8}, // filter size (KRSC)
366
+ {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _)
367
+ {3, 3}, // stride (stride_h, stride_w)
368
+ {1, 1} // dilation (dilation_h, dilation_w)
369
+ ));
370
+
371
+ ////////////////////////////////////////////////////////////////////////////////////
372
+ // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1)
373
+ ////////////////////////////////////////////////////////////////////////////////////
374
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
375
+ {1, 15, 19, 160}, // input size (NHWC)
376
+ {224, 1, 1, 160}, // filter size (KRSC)
377
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
378
+ {1, 1}, // stride (stride_h, stride_w)
379
+ {1, 1} // dilation (dilation_h, dilation_w)
380
+ ));
381
+
382
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
383
+ {1, 19, 37, 160}, // input size (NHWC)
384
+ {224, 3, 3, 160}, // filter size (KRSC)
385
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
386
+ {2, 2}, // stride (stride_h, stride_w)
387
+ {1, 1} // dilation (dilation_h, dilation_w)
388
+ ));
389
+
390
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
391
+ {1, 16, 16, 160}, // input size (NHWC)
392
+ {224, 2, 3, 160}, // filter size (KRSC)
393
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
394
+ {1, 1}, // stride (stride_h, stride_w)
395
+ {1, 1} // dilation (dilation_h, dilation_w)
396
+ ));
397
+
398
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
399
+ {1, 23, 21, 128}, // input size (NHWC)
400
+ {224, 3, 3, 128}, // filter size (KRSC)
401
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
402
+ {1, 1}, // stride (stride_h, stride_w)
403
+ {1, 1} // dilation (dilation_h, dilation_w)
404
+ ));
405
+
406
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
407
+ {1, 29, 37, 160}, // input size (NHWC)
408
+ {224, 5, 5, 160}, // filter size (KRSC)
409
+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _)
410
+ {1, 1}, // stride (stride_h, stride_w)
411
+ {1, 1} // dilation (dilation_h, dilation_w)
412
+ ));
413
+
414
+ ////////////////////////////////////////////////////////////////////////////////////
415
+ // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64}
416
+ ////////////////////////////////////////////////////////////////////////////////////
417
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
418
+ {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC)
419
+ {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC)
420
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
421
+ {1, 1}, // stride (stride_h, stride_w)
422
+ {1, 1} // dilation (dilation_h, dilation_w)
423
+ ));
424
+
425
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
426
+ {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC)
427
+ {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC)
428
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
429
+ {1, 1}, // stride (stride_h, stride_w)
430
+ {1, 1} // dilation (dilation_h, dilation_w)
431
+ ));
432
+
433
+ ////////////////////////////////////////////////////////////////////////////////////
434
+ // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2)
435
+ ////////////////////////////////////////////////////////////////////////////////////
436
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
437
+ {1, 13, 16, 288}, // input size (NHWC)
438
+ {160, 5, 5, 288}, // filter size (KRSC)
439
+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _)
440
+ {2, 2}, // stride (stride_h, stride_w)
441
+ {1, 1} // dilation (dilation_h, dilation_w)
442
+ ));
443
+
444
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
445
+ {1, 55, 51, 256}, // input size (NHWC)
446
+ {512, 1, 1, 256}, // filter size (KRSC)
447
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
448
+ {2, 2}, // stride (stride_h, stride_w)
449
+ {1, 1} // dilation (dilation_h, dilation_w)
450
+ ));
451
+
452
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
453
+ {1, 71, 80, 32}, // input size (NHWC)
454
+ {64, 5, 5, 32}, // filter size (KRSC)
455
+ {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _)
456
+ {2, 2}, // stride (stride_h, stride_w)
457
+ {1, 1} // dilation (dilation_h, dilation_w)
458
+ ));
459
+
460
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
461
+ {1, 224, 224, 8}, // input size (NHWC)
462
+ {64, 7, 7, 8}, // filter size (KRSC)
463
+ {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _)
464
+ {2, 2}, // stride (stride_h, stride_w)
465
+ {1, 1} // dilation (dilation_h, dilation_w)
466
+ ));
467
+
468
+ ////////////////////////////////////////////////////////////////////////////////////
469
+ // Medium input size stride (3, 3), filter (3, 3), non-default padding
470
+ ////////////////////////////////////////////////////////////////////////////////////
471
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
472
+ {1, 27, 23, 256}, // input size (NHWC)
473
+ {512, 3, 3, 256}, // filter size (KRSC)
474
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
475
+ {3, 3}, // stride (stride_h, stride_w)
476
+ {1, 1} // dilation (dilation_h, dilation_w)
477
+ ));
478
+
479
+ ////////////////////////////////////////////////////////////////////////////////////
480
+ // Medium input size padding > stride, asymmetric filter, padding and striding
481
+ ////////////////////////////////////////////////////////////////////////////////////
482
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
483
+ {1, 27, 31, 256}, // input size (NHWC)
484
+ {512, 3, 3, 256}, // filter size (KRSC)
485
+ {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _)
486
+ {3, 4}, // stride (stride_h, stride_w)
487
+ {1, 1} // dilation (dilation_h, dilation_w)
488
+ ));
489
+
490
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
491
+ {1, 27, 35, 256}, // input size (NHWC)
492
+ {512, 7, 5, 256}, // filter size (KRSC)
493
+ {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _)
494
+ {3, 5}, // stride (stride_h, stride_w)
495
+ {1, 1} // dilation (dilation_h, dilation_w)
496
+ ));
497
+
498
+ ////////////////////////////////////////////////////////////////////////////////////
499
+ // Medium input size *mixed* stride (1, 2) and (2, 1),
500
+ // filter (3, 3), default padding
501
+ ////////////////////////////////////////////////////////////////////////////////////
502
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
503
+ {1, 27, 27, 256}, // input size (NHWC)
504
+ {512, 3, 3, 256}, // filter size (KRSC)
505
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
506
+ {1, 2}, // stride (stride_h, stride_w)
507
+ {1, 1} // dilation (dilation_h, dilation_w)
508
+ ));
509
+
510
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
511
+ {1, 27, 27, 256}, // input size (NHWC)
512
+ {512, 3, 3, 256}, // filter size (KRSC)
513
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
514
+ {2, 1}, // stride (stride_h, stride_w)
515
+ {1, 1} // dilation (dilation_h, dilation_w)
516
+ ));
517
+
518
+ /////////////////////////////////////////////////////////////////////////////
519
+ // Additional input size
520
+ /////////////////////////////////////////////////////////////////////////////
521
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
522
+ {3, 28, 28, 256}, // input size (NHWC)
523
+ {256, 2, 2, 256}, // filter size (KRSC)
524
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
525
+ {2, 2}, // stride (stride_h, stride_w)
526
+ {1, 1} // dilation (dilation_h, dilation_w)
527
+ ));
528
+
529
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
530
+ {1, 32, 32, 16}, // input size (NHWC)
531
+ {32, 3, 3, 16}, // filter size (KRSC)
532
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
533
+ {6, 2}, // stride (stride_h, stride_w)
534
+ {1, 1} // dilation (dilation_h, dilation_w)
535
+ ));
536
+
537
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
538
+ {32, 24, 32, 32}, // input size (NHWC)
539
+ {32, 1, 2, 32}, // filter size (KRSC)
540
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
541
+ {1, 1}, // stride (stride_h, stride_w)
542
+ {1, 1} // dilation (dilation_h, dilation_w)
543
+ ));
544
+
545
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
546
+ {4, 4, 5, 128}, // input size (NHWC)
547
+ {256, 3, 6, 128}, // filter size (KRSC)
548
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
549
+ {1, 1}, // stride (stride_h, stride_w)
550
+ {1, 1}, // dilation (dilation_h, dilation_w)
551
+ {4, 3, 3, 256} // output size (NPQK)
552
+ ));
553
+
554
+ conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize(
555
+ {4, 2, 3, 256}, // input size (NHWC)
556
+ {328, 3, 5, 256}, // filter size (KRSC)
557
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
558
+ {1, 1}, // stride (stride_h, stride_w)
559
+ {1, 1}, // dilation (dilation_h, dilation_w)
560
+ {4, 1, 1, 328} // output size (NPQK)
561
+ ));
562
+ }
563
+
564
+
565
+ // Add a few large and rigorous convolution problem sizes
566
+ void initialize_conv2d_rigorous_sizes() {
567
+
568
+ #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
569
+ conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize(
570
+ {1, 124, 224, 96}, // input size (NHWC)
571
+ {24, 7, 7, 96}, // filter size (KRSC)
572
+ {1, 229, 129, 32} // output size (NPQK)
573
+ ));
574
+
575
+ conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize(
576
+ {1, 233, 35, 48}, // input size (NHWC)
577
+ {24, 7, 5, 48}, // filter size (KRSC)
578
+ {1, 233, 35, 24} // output size (NPQK)
579
+ ));
580
+
581
+ #endif
582
+
583
+ }
584
+
585
+
586
+ // Add resent50 layers to unit testing sizes
587
+ void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){
588
+
589
+ #if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass
590
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
591
+ [1, 224, 224, 3], // input size (NHWC)
592
+ [64, 7, 7, 3], // filter size (KRSC)
593
+ [3, 3, 3, 3], // padding (pad_h, _, pad_w, _)
594
+ [2, 2], // stride (stride_h, stride_w)
595
+ [1, 1], // dilation (dilation_h, dilation_w)
596
+ ));
597
+ #endif
598
+
599
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
600
+ {batch_size, 56, 56, 64}, // input size (NHWC)
601
+ {256, 1, 1, 64}, // filter size (KRSC)
602
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
603
+ {1, 1}, // stride (stride_h, stride_w)
604
+ {1, 1} // dilation (dilation_h, dilation_w)
605
+ ));
606
+
607
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
608
+ {batch_size, 56, 56, 64}, // input size (NHWC)
609
+ {64, 1, 1, 64}, // filter size (KRSC)
610
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
611
+ {1, 1}, // stride (stride_h, stride_w)
612
+ {1, 1} // dilation (dilation_h, dilation_w)
613
+ ));
614
+
615
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
616
+ {batch_size, 56, 56, 64}, // input size (NHWC)
617
+ {64, 3, 3, 64}, // filter size (KRSC)
618
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
619
+ {1, 1}, // stride (stride_h, stride_w)
620
+ {1, 1} // dilation (dilation_h, dilation_w)
621
+ ));
622
+
623
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
624
+ {batch_size, 56, 56, 256}, // input size (NHWC)
625
+ {64, 1, 1, 256}, // filter size (KRSC)
626
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
627
+ {1, 1}, // stride (stride_h, stride_w)
628
+ {1, 1} // dilation (dilation_h, dilation_w)
629
+ ));
630
+
631
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
632
+ {batch_size, 56, 56, 256}, // input size (NHWC)
633
+ {512, 1, 1, 256}, // filter size (KRSC)
634
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
635
+ {2, 2}, // stride (stride_h, stride_w)
636
+ {1, 1} // dilation (dilation_h, dilation_w)
637
+ ));
638
+
639
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
640
+ {batch_size, 56, 56, 256}, // input size (NHWC)
641
+ {128, 1, 1, 256}, // filter size (KRSC)
642
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
643
+ {2, 2}, // stride (stride_h, stride_w)
644
+ {1, 1} // dilation (dilation_h, dilation_w)
645
+ ));
646
+
647
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
648
+ {batch_size, 28, 28, 128}, // input size (NHWC)
649
+ {128, 3, 3, 128}, // filter size (KRSC)
650
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
651
+ {1, 1}, // stride (stride_h, stride_w)
652
+ {1, 1} // dilation (dilation_h, dilation_w)
653
+ ));
654
+
655
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
656
+ {batch_size, 28, 28, 128}, // input size (NHWC)
657
+ {512, 1, 1, 128}, // filter size (KRSC)
658
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
659
+ {1, 1}, // stride (stride_h, stride_w)
660
+ {1, 1} // dilation (dilation_h, dilation_w)
661
+ ));
662
+
663
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
664
+ {batch_size, 28, 28, 512}, // input size (NHWC)
665
+ {128, 1, 1, 512}, // filter size (KRSC)
666
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
667
+ {1, 1}, // stride (stride_h, stride_w)
668
+ {1, 1} // dilation (dilation_h, dilation_w)
669
+ ));
670
+
671
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
672
+ {batch_size, 28, 28, 512}, // input size (NHWC)
673
+ {1024, 1, 1, 512}, // filter size (KRSC)
674
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
675
+ {2, 2}, // stride (stride_h, stride_w)
676
+ {1, 1} // dilation (dilation_h, dilation_w)
677
+ ));
678
+
679
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
680
+ {batch_size, 28, 28, 512}, // input size (NHWC)
681
+ {256, 1, 1, 512}, // filter size (KRSC)
682
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
683
+ {2, 2}, // stride (stride_h, stride_w)
684
+ {1, 1} // dilation (dilation_h, dilation_w)
685
+ ));
686
+
687
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
688
+ {batch_size, 14, 14, 256}, // input size (NHWC)
689
+ {256, 3, 3, 256}, // filter size (KRSC)
690
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
691
+ {1, 1}, // stride (stride_h, stride_w)
692
+ {1, 1} // dilation (dilation_h, dilation_w)
693
+ ));
694
+
695
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
696
+ {batch_size, 14, 14, 256}, // input size (NHWC)
697
+ {1024, 1, 1, 256}, // filter size (KRSC)
698
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
699
+ {1, 1}, // stride (stride_h, stride_w)
700
+ {1, 1} // dilation (dilation_h, dilation_w)
701
+ ));
702
+
703
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
704
+ {batch_size, 14, 14, 1024}, // input size (NHWC)
705
+ {256, 1, 1, 1024}, // filter size (KRSC)
706
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
707
+ {1, 1}, // stride (stride_h, stride_w)
708
+ {1, 1} // dilation (dilation_h, dilation_w)
709
+ ));
710
+
711
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
712
+ {batch_size, 14, 14, 1024}, // input size (NHWC)
713
+ {2048, 1, 1, 1024}, // filter size (KRSC)
714
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
715
+ {2, 2}, // stride (stride_h, stride_w)
716
+ {1, 1} // dilation (dilation_h, dilation_w)
717
+ ));
718
+
719
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
720
+ {batch_size, 14, 14, 1024}, // input size (NHWC)
721
+ {512, 1, 1, 1024}, // filter size (KRSC)
722
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
723
+ {2, 2}, // stride (stride_h, stride_w)
724
+ {1, 1} // dilation (dilation_h, dilation_w)
725
+ ));
726
+
727
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
728
+ {batch_size, 7, 7, 512}, // input size (NHWC)
729
+ {512, 3, 3, 512}, // filter size (KRSC)
730
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
731
+ {1, 1}, // stride (stride_h, stride_w)
732
+ {1, 1} // dilation (dilation_h, dilation_w)
733
+ ));
734
+
735
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
736
+ {batch_size, 7, 7, 512}, // input size (NHWC)
737
+ {2048, 1, 1, 512}, // filter size (KRSC)
738
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
739
+ {1, 1}, // stride (stride_h, stride_w)
740
+ {1, 1} // dilation (dilation_h, dilation_w)
741
+ ));
742
+
743
+ conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize(
744
+ {batch_size, 7, 7, 2048}, // input size (NHWC)
745
+ {512, 1, 1, 2048}, // filter size (KRSC)
746
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
747
+ {1, 1}, // stride (stride_h, stride_w)
748
+ {1, 1} // dilation (dilation_h, dilation_w)
749
+ ));
750
+ }
751
+
752
+ };
753
+
754
+
755
+ ////////////////////////////////////////////////////////////////////////////
756
+ /// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and
757
+ /// important network sizes
758
+ ////////////////////////////////////////////////////////////////////////////
759
+ struct TestbedGroupConv2dProblemSizes {
760
+
761
+ //
762
+ // Data members
763
+ //
764
+ int threadblock_n;
765
+ int threadblock_k;
766
+ int minimum_channel_size;
767
+
768
+ Conv2dProblemVector default_single_group_sizes;
769
+ Conv2dProblemVector default_multiple_group_sizes;
770
+
771
+ //
772
+ // Methods
773
+ //
774
+ /// Default ctor
775
+ TestbedGroupConv2dProblemSizes(
776
+ int threadblock_n_,
777
+ int threadblock_k_,
778
+ int minimum_channel_size_ = 64)
779
+ : threadblock_n (threadblock_n_),
780
+ threadblock_k (threadblock_k_),
781
+ minimum_channel_size (minimum_channel_size_) {
782
+ initialize_group_conv2d_default_sizes();
783
+ filter_all();
784
+ }
785
+
786
+ /// Eliminates some illegal cases
787
+ void filter_all() {
788
+
789
+ Conv2dProblemVector *problems_vectors[] = {
790
+ &default_single_group_sizes,
791
+ &default_multiple_group_sizes
792
+ };
793
+
794
+ for (Conv2dProblemVector *problems : problems_vectors) {
795
+ Conv2dProblemVector filtered;
796
+
797
+ for (cutlass::conv::Conv2dProblemSize const & problem : *problems) {
798
+ if (!((problem.C / problem.groups) % minimum_channel_size)) {
799
+ filtered.push_back(problem);
800
+ }
801
+ }
802
+
803
+ *problems = filtered;
804
+ }
805
+ }
806
+
807
+ // Add a few standard convolution problem sizes
808
+ void initialize_group_conv2d_default_sizes() {
809
+
810
+ ////////////////////////////////////////////////////////////////////////////////////
811
+ // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0
812
+ // One CTA calculates a single group
813
+ ////////////////////////////////////////////////////////////////////////////////////
814
+
815
+ for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) {
816
+ // groups = 2, 3, 4
817
+ for (int groups = 2; groups < 5; ++groups) {
818
+
819
+ int conv_k = cta_per_group_k * threadblock_n * groups;
820
+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
821
+ {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC)
822
+ {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC)
823
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
824
+ {1, 1}, // stride (stride_h, stride_w)
825
+ {1, 1}, // dilation (dilation_h, dilation_w)
826
+ cutlass::conv::Mode::kCrossCorrelation,
827
+ 1, // split_k_slices
828
+ groups // groups
829
+ ));
830
+
831
+ } // loop groups
832
+ } // loop cta_per_group_k
833
+
834
+ // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K
835
+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
836
+ {1, 8, 8, threadblock_k}, // input size (NHWC)
837
+ {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC)
838
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
839
+ {1, 1}, // stride (stride_h, stride_w)
840
+ {1, 1}, // dilation (dilation_h, dilation_w)
841
+ cutlass::conv::Mode::kCrossCorrelation,
842
+ 1, // split_k_slices
843
+ 2 // groups
844
+ ));
845
+
846
+ // Larger problem sizes
847
+
848
+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
849
+ {1, 56, 56, 696}, // input size (NHWC)
850
+ {768, 3, 3, 232}, // filter size (KRSC)
851
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
852
+ {2, 2}, // stride (stride_h, stride_w)
853
+ {1, 1}, // dilation (dilation_h, dilation_w)
854
+ cutlass::conv::Mode::kCrossCorrelation,
855
+ 1, // split_k_slices
856
+ 3 // groups
857
+ ));
858
+ default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
859
+ {1, 14, 14, 1392}, // input size (NHWC)
860
+ {1536, 3, 3, 232}, // filter size (KRSC)
861
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
862
+ {1, 1}, // stride (stride_h, stride_w)
863
+ {1, 1}, // dilation (dilation_h, dilation_w)
864
+ cutlass::conv::Mode::kCrossCorrelation,
865
+ 1, // split_k_slices
866
+ 3 // groups
867
+ ));
868
+
869
+ ////////////////////////////////////////////////////////////////////////////////////
870
+ // One CTA calculate multiple groups: CTA::N % k_per_group = 0
871
+ ////////////////////////////////////////////////////////////////////////////////////
872
+
873
+ // 2 groups per CTA
874
+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
875
+ {1, 8, 8, threadblock_k * 4}, // input size (NHWC)
876
+ {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC)
877
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
878
+ {1, 1}, // stride (stride_h, stride_w)
879
+ {1, 1}, // dilation (dilation_h, dilation_w)
880
+ cutlass::conv::Mode::kCrossCorrelation,
881
+ 1, // split_k_slices
882
+ 2 // groups
883
+ ));
884
+
885
+ // 2 groups per CTA and partial gemm_k
886
+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
887
+ {1, 8, 8, threadblock_k}, // input size (NHWC)
888
+ {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC)
889
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
890
+ {1, 1}, // stride (stride_h, stride_w)
891
+ {1, 1}, // dilation (dilation_h, dilation_w)
892
+ cutlass::conv::Mode::kCrossCorrelation,
893
+ 1, // split_k_slices
894
+ 2 // groups
895
+ ));
896
+
897
+ // 4 groups per CTA
898
+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
899
+ {1, 8, 8, threadblock_k * 8}, // input size (NHWC)
900
+ {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC)
901
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
902
+ {1, 1}, // stride (stride_h, stride_w)
903
+ {1, 1}, // dilation (dilation_h, dilation_w)
904
+ cutlass::conv::Mode::kCrossCorrelation,
905
+ 1, // split_k_slices
906
+ 4 // groups
907
+ ));
908
+
909
+ // 4 groups per CTA and partial gemm_k
910
+ default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize(
911
+ {1, 8, 8, threadblock_k * 2}, // input size (NHWC)
912
+ {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC)
913
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
914
+ {1, 1}, // stride (stride_h, stride_w)
915
+ {1, 1}, // dilation (dilation_h, dilation_w)
916
+ cutlass::conv::Mode::kCrossCorrelation,
917
+ 1, // split_k_slices
918
+ 4 // groups
919
+ ));
920
+ }
921
+
922
+ };
923
+
924
+
925
+ } // namespace device
926
+ } // namespace conv
927
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h ADDED
@@ -0,0 +1,818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed
33
+ */
34
+ #pragma once
35
+
36
+ #include <fstream>
37
+
38
+ #include "../../common/cutlass_unit_test.h"
39
+ #include "cutlass/cutlass.h"
40
+
41
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
42
+ #include "cutlass/reduction/device/reduce_split_k.h"
43
+ #include "cutlass/reduction/thread/reduction_operators.h"
44
+
45
+ #include "conv2d_problems.h"
46
+
47
+ #include "cutlass/util/host_tensor.h"
48
+ #include "cutlass/util/reference/host/tensor_fill.h"
49
+ #include "cutlass/util/reference/device/tensor_compare.h"
50
+ #include "cutlass/util/reference/host/tensor_compare.h"
51
+
52
+ #include "cutlass/util/reference/host/convolution.h"
53
+ #include "cutlass/util/reference/device/convolution.h"
54
+
55
+ #include "cutlass/core_io.h"
56
+ #include "cutlass/util/tensor_view_io.h"
57
+
58
+ #include "../cache_testbed_output.h"
59
+
60
+ namespace test {
61
+ namespace conv {
62
+ namespace device {
63
+
64
+ template <typename Conv2d>
65
+ class TestbedConv2d {
66
+ public:
67
+
68
+ using ElementA = typename Conv2d::ElementA;
69
+ using LayoutA = typename Conv2d::LayoutA;
70
+ using ElementB = typename Conv2d::ElementB;
71
+ using LayoutB = typename Conv2d::LayoutB;
72
+ using ElementC = typename Conv2d::ElementC;
73
+ using LayoutC = typename Conv2d::LayoutC;
74
+ using ElementAccumulator = typename Conv2d::ElementAccumulator;
75
+ using ElementCompute = typename Conv2d::ElementCompute;
76
+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
77
+
78
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
79
+
80
+ /// Reduction kernel
81
+ using ReductionOp = cutlass::reduction::thread::ReduceAdd<
82
+ ElementAccumulator,
83
+ typename EpilogueOutputOp::ElementAccumulator,
84
+ EpilogueOutputOp::kCount
85
+ >;
86
+
87
+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
88
+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
89
+ EpilogueOutputOp,
90
+ ReductionOp
91
+ >;
92
+
93
+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
94
+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
95
+
96
+ public:
97
+
98
+ /// Initialization
99
+ cutlass::Distribution::Kind init_A;
100
+ cutlass::Distribution::Kind init_B;
101
+ cutlass::Distribution::Kind init_C;
102
+ uint64_t seed;
103
+
104
+ cutlass::HostTensor<ElementA, LayoutA> tensor_A;
105
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B;
106
+ cutlass::HostTensor<ElementC, LayoutC> tensor_C;
107
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
108
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
109
+
110
+ int tested_problem_count;
111
+
112
+ public:
113
+
114
+ TestbedConv2d(
115
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
116
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
117
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
118
+ uint64_t seed_ = 2080
119
+ ):
120
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {
121
+
122
+ }
123
+
124
+ /// Helper to initialize a tensor view
125
+ template <typename Element, typename Layout>
126
+ void initialize_tensor(
127
+ cutlass::TensorView<Element, Layout> view,
128
+ cutlass::Distribution::Kind dist_kind,
129
+ uint64_t seed) {
130
+
131
+ if (dist_kind == cutlass::Distribution::Uniform) {
132
+
133
+ int scope;
134
+ int bits = cutlass::sizeof_bits<Element>::value;
135
+
136
+ if (bits <= 8) {
137
+ scope = 2;
138
+ }
139
+ else if (bits == 16) {
140
+ if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
141
+ scope = 3;
142
+ }
143
+ else {
144
+ scope = 5;
145
+ }
146
+ }
147
+ else {
148
+ scope = 8;
149
+ }
150
+ cutlass::reference::host::TensorFillRandomUniform(
151
+ view, seed, scope, -scope, 0);
152
+ }
153
+ else if (dist_kind == cutlass::Distribution::Identity) {
154
+
155
+ cutlass::reference::host::TensorFillIdentity(view);
156
+ }
157
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
158
+
159
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
160
+ }
161
+ else if (dist_kind == cutlass::Distribution::Sequential) {
162
+
163
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
164
+ }
165
+ else {
166
+ }
167
+ }
168
+
169
+ void initialize(
170
+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
171
+
172
+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
173
+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
174
+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
175
+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
176
+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
177
+
178
+ initialize_tensor(tensor_A.host_view(), init_A, seed);
179
+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
180
+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
181
+
182
+ tensor_A.sync_device();
183
+ tensor_B.sync_device();
184
+ tensor_C.sync_device();
185
+ tensor_D_computed.sync_device();
186
+ tensor_D_reference.sync_device();
187
+ }
188
+
189
+ bool sufficient() const {
190
+ //
191
+ // Determine SMEM requirements and waive if not satisfied
192
+ //
193
+
194
+ size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);
195
+
196
+ cudaDeviceProp properties;
197
+ int device_idx;
198
+ cudaError_t result = cudaGetDevice(&device_idx);
199
+
200
+ if (result != cudaSuccess) {
201
+ throw std::runtime_error("cudaGetDevice() API call failed.");
202
+ }
203
+
204
+ result = cudaGetDeviceProperties(&properties, device_idx);
205
+
206
+ if (result != cudaSuccess) {
207
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
208
+ }
209
+
210
+ if (properties.sharedMemPerBlockOptin < smem_size) {
211
+ return false;
212
+ }
213
+
214
+ return true;
215
+ }
216
+
217
+ /// Executes one test
218
+ bool run(
219
+ cutlass::conv::Conv2dProblemSize const &problem_size,
220
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
221
+ ElementCompute alpha = ElementCompute(1),
222
+ ElementCompute beta = ElementCompute(0)) {
223
+
224
+ // Waive test if insufficient CUDA device
225
+ if (!sufficient()) {
226
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
227
+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
228
+ }
229
+ return true;
230
+ }
231
+
232
+ // increment tested problem count run by the testbed
233
+ tested_problem_count++;
234
+
235
+ #if 0 // display conv2d problem size for debugging
236
+ std::cout << problem_size << std::endl
237
+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
238
+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
239
+ << std::endl;
240
+ #endif
241
+
242
+ initialize(problem_size);
243
+
244
+ // configure the operator
245
+ Conv2d conv2d_op;
246
+
247
+ typename Conv2d::Arguments conv2d_args(
248
+ problem_size,
249
+ tensor_A.device_ref(),
250
+ tensor_B.device_ref(),
251
+ tensor_C.device_ref(),
252
+ tensor_D_computed.device_ref(),
253
+ {alpha, beta},
254
+ split_k_mode
255
+ );
256
+
257
+ // find workspace requirement for parallel split-k reduction
258
+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
259
+
260
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
261
+
262
+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());
263
+
264
+ if (status != cutlass::Status::kSuccess) {
265
+ cudaError_t error = cudaGetLastError();
266
+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
267
+ return true;
268
+ }
269
+
270
+ // conv2d operation with parallel split-k-mode
271
+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
272
+
273
+ // conv2d output is written to workspace in global memory
274
+ conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
275
+ // accumulate mma for each cta in k-dimension (1.0 * A * B)
276
+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
277
+ // update conv2d operator arguments
278
+ status = conv2d_op.update(conv2d_args, workspace.get());
279
+ }
280
+
281
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
282
+ if (status != cutlass::Status::kSuccess) {
283
+ return false;
284
+ }
285
+
286
+ // run conv2d operator
287
+ status = conv2d_op();
288
+
289
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
290
+ if (status != cutlass::Status::kSuccess) {
291
+ std::cerr << "Failed to run." << std::endl;
292
+ return false;
293
+ }
294
+
295
+
296
+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
297
+
298
+ // configure parallel reduction operator
299
+ ReductionDevice reduction_op;
300
+
301
+ typename ReductionDevice::Arguments reduction_args(
302
+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
303
+ problem_size.split_k_slices,
304
+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
305
+ {
306
+ reinterpret_cast<ElementAccumulator*> (workspace.get()),
307
+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
308
+ },
309
+ {
310
+ tensor_D_computed.device_data(),
311
+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
312
+ },
313
+ {
314
+ tensor_C.device_data(),
315
+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
316
+ },
317
+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
318
+ {alpha, beta}
319
+ );
320
+
321
+ status = reduction_op.initialize(reduction_args, nullptr);
322
+
323
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
324
+ if (status != cutlass::Status::kSuccess) {
325
+ return false;
326
+ }
327
+
328
+ // run prallel reduction kernel
329
+ status = reduction_op();
330
+
331
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
332
+ if (status != cutlass::Status::kSuccess) {
333
+ return false;
334
+ }
335
+ }
336
+ bool passed = false;
337
+
338
+ cudaError_t result = cudaDeviceSynchronize();
339
+ EXPECT_EQ(result, cudaSuccess) << " device reference error: "
340
+ << cudaGetErrorString(result);
341
+
342
+ tensor_D_computed.sync_host();
343
+
344
+ //
345
+ // Reference check - support caching results
346
+ //
347
+
348
+ CachedTestKey cached_test_key = CreateCachedConv2dTestKey<
349
+ ElementA, LayoutA,
350
+ ElementB, LayoutB,
351
+ ElementC, LayoutC,
352
+ ElementAccumulator,
353
+ ElementCompute
354
+ >(
355
+ kConvolutionalOperator,
356
+ problem_size,
357
+ alpha,
358
+ beta,
359
+ tensor_A.host_view(),
360
+ tensor_B.host_view(),
361
+ tensor_C.host_view()
362
+ );
363
+
364
+ //
365
+ // Look for the cached key
366
+ //
367
+
368
+ bool cached_result_loaded = false;
369
+ CachedTestResult cached_test_result;
370
+
371
+ std::string conv2d_result_cache_name =
372
+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
373
+
374
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
375
+
376
+ CachedTestResultListing cached_results(conv2d_result_cache_name);
377
+
378
+ auto cached = cached_results.find(cached_test_key);
379
+
380
+ cached_result_loaded = cached.first;
381
+ if (cached_result_loaded) {
382
+ cached_test_result = cached.second;
383
+ }
384
+ }
385
+
386
+ if (!cached_result_loaded) {
387
+
388
+ #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
389
+
390
+ cutlass::reference::device::Conv2d<
391
+ ElementA,
392
+ LayoutA,
393
+ ElementB,
394
+ LayoutB,
395
+ ElementC,
396
+ LayoutC,
397
+ ElementCompute,
398
+ ElementAccumulator
399
+ >(
400
+ kConvolutionalOperator,
401
+ problem_size,
402
+ tensor_A.device_ref(),
403
+ tensor_B.device_ref(),
404
+ tensor_C.device_ref(),
405
+ tensor_D_reference.device_ref(),
406
+ alpha,
407
+ beta);
408
+
409
+ // sync host (copy device data to host) for dumping error output in case of mismatches
410
+ tensor_D_reference.sync_host();
411
+
412
+ #else
413
+
414
+ cutlass::reference::host::Conv2d<
415
+ ElementA,
416
+ LayoutA,
417
+ ElementB,
418
+ LayoutB,
419
+ ElementC,
420
+ LayoutC,
421
+ ElementCompute,
422
+ ElementAccumulator
423
+ >(
424
+ kConvolutionalOperator,
425
+ problem_size,
426
+ tensor_A.host_ref(),
427
+ tensor_B.host_ref(),
428
+ tensor_C.host_ref(),
429
+ tensor_D_reference.host_ref(),
430
+ alpha,
431
+ beta);
432
+
433
+ #endif
434
+
435
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
436
+
437
+ cached_test_result.D = TensorHash(tensor_D_reference.host_view());
438
+
439
+ CachedTestResultListing cached_results(conv2d_result_cache_name);
440
+
441
+ cached_results.append(cached_test_key, cached_test_result);
442
+ cached_results.write(conv2d_result_cache_name);
443
+ }
444
+ } // if (!cached_result_loaded)
445
+
446
+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
447
+
448
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
449
+ passed = (tensor_D_hash == cached_test_result.D);
450
+
451
+ EXPECT_EQ(tensor_D_hash, cached_test_result.D)
452
+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
453
+ }
454
+ else {
455
+
456
+ passed = cutlass::reference::host::TensorEquals(
457
+ tensor_D_computed.host_view(),
458
+ tensor_D_reference.host_view());
459
+ }
460
+
461
+ EXPECT_TRUE(passed);
462
+
463
+ std::stringstream ss_problem_size_text;
464
+ ss_problem_size_text << "nhwc_"
465
+ << problem_size.N << "x"
466
+ << problem_size.H << "x"
467
+ << problem_size.W << "x"
468
+ << problem_size.C
469
+ << "_krsc_"
470
+ << problem_size.K << "x"
471
+ << problem_size.R << "x"
472
+ << problem_size.S << "x"
473
+ << problem_size.C
474
+ << "_padding_"
475
+ << problem_size.pad_h << "x"
476
+ << problem_size.pad_w
477
+ << "_stride_"
478
+ << problem_size.stride_h << "x"
479
+ << problem_size.stride_w
480
+ << "_dilation_"
481
+ << problem_size.dilation_h << "x"
482
+ << problem_size.dilation_w << "_"
483
+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_");
484
+
485
+ if (!passed) {
486
+ std::stringstream fname;
487
+
488
+ fname << "error_Conv2d_ImplicitGemm_device_"
489
+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
490
+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
491
+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
492
+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
493
+ << ss_problem_size_text.str()
494
+ << Conv2d::ThreadblockShape::kM << "x"
495
+ << Conv2d::ThreadblockShape::kN << "x"
496
+ << Conv2d::ThreadblockShape::kK << "_"
497
+ << Conv2d::WarpShape::kM << "x"
498
+ << Conv2d::WarpShape::kN << "x"
499
+ << Conv2d::WarpShape::kK << ".txt";
500
+
501
+ std::cout << fname.str() << std::endl;
502
+
503
+ std::ofstream results(fname.str());
504
+
505
+ results << problem_size << std::endl;
506
+
507
+ results
508
+ << "\nA:\n" << tensor_A.host_view() << "\n"
509
+ << "\nB:\n" << tensor_B.host_view() << "\n"
510
+ << "\nC:\n" << tensor_C.host_view() << "\n";
511
+
512
+ results << "\nD reference (hash: " << cached_test_result.D << ")\n";
513
+
514
+ if (!cached_result_loaded) {
515
+ results
516
+ << tensor_D_reference.host_view() << "\n";
517
+ }
518
+
519
+ results
520
+ << "\nD computed (hash: " << tensor_D_hash << ")\n"
521
+ << tensor_D_computed.host_view() << "\n";
522
+
523
+ }
524
+
525
+ return passed;
526
+ }
527
+
528
+ };
529
+
530
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
531
+
532
+ template <typename ImplicitGemm>
533
+ bool TestSpecificConv2d(
534
+ const Conv2dProblemVector & problem_sizes) {
535
+
536
+ bool passed = true;
537
+
538
+ //
539
+ // Testbed object
540
+ //
541
+
542
+ TestbedConv2d<ImplicitGemm> testbed;
543
+
544
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
545
+ for(auto conv_problem : problem_sizes) {
546
+
547
+ //
548
+ // Test
549
+ //
550
+
551
+ // test mode = xcross
552
+ passed = testbed.run(
553
+ conv_problem,
554
+ cutlass::conv::SplitKMode::kSerial);
555
+
556
+ if (!passed) {
557
+ return false;
558
+ }
559
+
560
+ // test mode = convolution
561
+ passed = testbed.run(
562
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
563
+ cutlass::conv::SplitKMode::kSerial);
564
+
565
+ if (!passed) {
566
+ return false;
567
+ }
568
+ }
569
+
570
+ return true;
571
+ }
572
+
573
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////
574
+ // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
575
+ // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
576
+ // Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
577
+ // (conv_blacklist_sizes)
578
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
579
+ template <typename ImplicitGemm>
580
+ bool TestAllConv2d(
581
+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
582
+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {
583
+
584
+ bool passed = true;
585
+
586
+ //
587
+ // Testbed object
588
+ //
589
+
590
+ TestbedConv2d<ImplicitGemm> testbed;
591
+
592
+ //
593
+ // Get conv problem sizes to run conv operator
594
+ //
595
+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
596
+
597
+ // Vector of conv2d problem sizes to avoid duplicate runs
598
+ Conv2dProblemVector conv_tested_sizes;
599
+
600
+ // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes)
601
+ std::vector<Conv2dProblemVector> problem_vectors = {
602
+ conv_test_sizes, // run user specified sizes
603
+ conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
604
+ //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
605
+ #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
606
+ conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
607
+ #endif
608
+ };
609
+
610
+ // Flatten 2D problem_vectors into a 1D problem_sizes
611
+ std::vector<cutlass::conv::Conv2dProblemSize> problem_sizes;
612
+ for (auto problem_vector : problem_vectors) {
613
+ for(auto conv_problem : problem_vector) {
614
+ problem_sizes.push_back(conv_problem);
615
+ }
616
+ }
617
+
618
+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient)
619
+ // run the most rigorous problem size first
620
+ if (CutlassUnitTestProblemCount()) {
621
+ std::reverse(problem_sizes.begin(), problem_sizes.end());
622
+ }
623
+
624
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
625
+ for(auto conv_problem : problem_sizes) {
626
+
627
+ // Skip blacklist and avoid duplicate problem sizes
628
+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
629
+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
630
+ continue;
631
+ }
632
+
633
+ //
634
+ // Procedurally disable certain cases
635
+ //
636
+
637
+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
638
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
639
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
640
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
641
+ cutlass::conv::StrideSupport::kUnity)) {
642
+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
643
+ continue;
644
+ }
645
+ }
646
+
647
+ // Fixed channels algorithm requires channel count to match access size
648
+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
649
+ cutlass::conv::IteratorAlgorithm::kFixedChannels) {
650
+ if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) {
651
+ continue;
652
+ }
653
+ }
654
+
655
+ // Few channels algorithm requires channel count to match access size
656
+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
657
+ cutlass::conv::IteratorAlgorithm::kFewChannels) {
658
+ if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) {
659
+ continue;
660
+ }
661
+ }
662
+
663
+ // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w}
664
+ // Although strided dgrad works for all stride combinations, we are only going
665
+ // to run strided dgrad for non-unity strides
666
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
667
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
668
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
669
+ cutlass::conv::StrideSupport::kStrided)) {
670
+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
671
+ continue;
672
+ }
673
+ }
674
+
675
+ //
676
+ // Test
677
+ //
678
+ // push back tested problem size to avoid re-running duplicates
679
+ conv_tested_sizes.push_back(conv_problem);
680
+
681
+ // test mode = xcross
682
+ passed = testbed.run(
683
+ conv_problem,
684
+ cutlass::conv::SplitKMode::kSerial);
685
+
686
+ if (!passed) {
687
+ return false;
688
+ }
689
+
690
+ // test mode = convolution
691
+ passed = testbed.run(
692
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
693
+ cutlass::conv::SplitKMode::kSerial);
694
+
695
+ if (!passed) {
696
+ return false;
697
+ }
698
+
699
+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts
700
+ if (CutlassUnitTestProblemCount() &&
701
+ testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
702
+ return true;
703
+ }
704
+ }
705
+
706
+ // Small-channels convolution can't run here.
707
+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
708
+ cutlass::conv::IteratorAlgorithm::kFixedChannels) {
709
+
710
+ return true;
711
+ }
712
+
713
+ // Small-channels convolution can't run here.
714
+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm ==
715
+ cutlass::conv::IteratorAlgorithm::kFewChannels) {
716
+
717
+ return true;
718
+ }
719
+
720
+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode
721
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
722
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
723
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
724
+ cutlass::conv::StrideSupport::kStrided)) {
725
+
726
+ passed = testbed.run(
727
+ cutlass::conv::Conv2dProblemSize(
728
+ {1, 56, 56, 8}, // input size (NHWC)
729
+ {8, 1, 1, 8}, // filter size (KRSC)
730
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
731
+ {2, 2}, // stride (stride_h, stride_w)
732
+ {1, 1}), // dilation (dilation_h, dilation_w)
733
+ cutlass::conv::SplitKMode::kSerial,
734
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
735
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));
736
+
737
+ passed = testbed.run(
738
+ cutlass::conv::Conv2dProblemSize(
739
+ {1, 56, 56, 8}, // input size (NHWC)
740
+ {8, 1, 1, 8}, // filter size (KRSC)
741
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
742
+ {1, 1}, // stride (stride_h, stride_w)
743
+ {1, 1}) // dilation (dilation_h, dilation_w)
744
+ .reset_split_k_slices(2),
745
+ cutlass::conv::SplitKMode::kSerial,
746
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
747
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));
748
+
749
+ if (!passed) {
750
+ return false;
751
+ }
752
+
753
+ return passed;
754
+ }
755
+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
756
+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
757
+ // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
758
+ // alpha and beta for local testing, but only runs one value for alpha and beta.
759
+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
760
+ {1, 17, 11, 288}, // input size (NHWC)
761
+ {160, 3, 3, 288}, // filter size (KRSC)
762
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
763
+ {1, 1}, // stride (stride_h, stride_w)
764
+ {1, 1} // dilation (dilation_h, dilation_w)
765
+ );
766
+
767
+ cutlass::conv::SplitKMode split_k_modes [] = {
768
+ cutlass::conv::SplitKMode::kSerial,
769
+ cutlass::conv::SplitKMode::kParallel,
770
+ };
771
+
772
+ int split_k_slices[] = {
773
+ 1, 2, 3, 4, 201
774
+ };
775
+
776
+ double problem_alpha[] = {
777
+ 2.0
778
+ };
779
+
780
+ double problem_beta[] = {
781
+ 2.0
782
+ };
783
+
784
+ for (auto split_k_mode : split_k_modes) {
785
+ for (auto split_k_slice : split_k_slices) {
786
+ for (auto alpha : problem_alpha) {
787
+ for (auto beta : problem_beta) {
788
+
789
+ passed = testbed.run(
790
+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
791
+ split_k_mode,
792
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
793
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
794
+
795
+ if (!passed) {
796
+ return false;
797
+ }
798
+
799
+ // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts
800
+ if (CutlassUnitTestProblemCount() &&
801
+ testbed.tested_problem_count > CutlassUnitTestProblemCount()) {
802
+ return true;
803
+ }
804
+ }
805
+ }
806
+ }
807
+ }
808
+
809
+ return passed;
810
+ }
811
+
812
+ /////////////////////////////////////////////////////////////////////////////////////////////////
813
+
814
+ } // namespace device
815
+ } // namespace conv
816
+ } // namespace test
817
+
818
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h ADDED
@@ -0,0 +1,666 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed
33
+ */
34
+ #pragma once
35
+
36
+ #include <fstream>
37
+
38
+ #include "../../common/cutlass_unit_test.h"
39
+ #include "cutlass/cutlass.h"
40
+
41
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
42
+ #include "cutlass/reduction/device/reduce_split_k.h"
43
+ #include "cutlass/reduction/thread/reduction_operators.h"
44
+
45
+ #include "conv2d_problems.h"
46
+
47
+ #include "cutlass/util/host_tensor.h"
48
+ #include "cutlass/util/reference/host/tensor_fill.h"
49
+ #include "cutlass/util/reference/device/tensor_compare.h"
50
+ #include "cutlass/util/reference/host/tensor_compare.h"
51
+ #include "cutlass/util/host_reorder.h"
52
+
53
+ #include "cutlass/util/reference/host/convolution.h"
54
+ #include "cutlass/util/reference/device/convolution.h"
55
+
56
+ #include "cutlass/core_io.h"
57
+ #include "cutlass/util/tensor_view_io.h"
58
+
59
+ #include "../cache_testbed_output.h"
60
+
61
+ namespace test {
62
+ namespace conv {
63
+ namespace device {
64
+
65
+ template <typename Conv2d, int InterleavedK>
66
+ class InterleavedTestbedConv2d {
67
+ public:
68
+
69
+ using ElementA = typename Conv2d::ElementA;
70
+ using LayoutA = typename Conv2d::LayoutA;
71
+ using ElementB = typename Conv2d::ElementB;
72
+ using LayoutB = typename Conv2d::LayoutB;
73
+ using ElementC = typename Conv2d::ElementC;
74
+ using LayoutC = typename Conv2d::LayoutC;
75
+ using ElementAccumulator = typename Conv2d::ElementAccumulator;
76
+ using ElementCompute = typename Conv2d::ElementCompute;
77
+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
78
+
79
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
80
+
81
+ /// Reduction kernel
82
+ using ReductionOp = cutlass::reduction::thread::ReduceAdd<
83
+ ElementAccumulator,
84
+ typename EpilogueOutputOp::ElementAccumulator,
85
+ EpilogueOutputOp::kCount
86
+ >;
87
+
88
+ using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
89
+ cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
90
+ EpilogueOutputOp,
91
+ ReductionOp
92
+ >;
93
+
94
+ using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
95
+ using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
96
+
97
+ public:
98
+
99
+ /// Initialization
100
+ cutlass::Distribution::Kind init_A;
101
+ cutlass::Distribution::Kind init_B;
102
+ cutlass::Distribution::Kind init_C;
103
+ uint64_t seed;
104
+
105
+ cutlass::HostTensor<ElementA, LayoutA> tensor_A;
106
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B;
107
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B_reordered;
108
+ cutlass::HostTensor<ElementC, LayoutC> tensor_C;
109
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
110
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
111
+
112
+ public:
113
+
114
+ InterleavedTestbedConv2d(
115
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
116
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
117
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
118
+ uint64_t seed_ = 2080
119
+ ):
120
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
121
+
122
+ }
123
+
124
+ /// Helper to initialize a tensor view
125
+ template <typename Element, typename Layout>
126
+ void initialize_tensor(
127
+ cutlass::TensorView<Element, Layout> view,
128
+ cutlass::Distribution::Kind dist_kind,
129
+ uint64_t seed) {
130
+
131
+ if (dist_kind == cutlass::Distribution::Uniform) {
132
+
133
+ int scope;
134
+ int bits = cutlass::sizeof_bits<Element>::value;
135
+
136
+ if (bits <= 8) {
137
+ scope = 2;
138
+ }
139
+ else if (bits == 16) {
140
+ scope = 3;
141
+ }
142
+ else {
143
+ scope = 8;
144
+ }
145
+ cutlass::reference::host::TensorFillRandomUniform(
146
+ view, seed, scope, -scope, 0);
147
+ }
148
+ else if (dist_kind == cutlass::Distribution::Identity) {
149
+
150
+ cutlass::reference::host::TensorFillIdentity(view);
151
+ }
152
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
153
+
154
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
155
+ }
156
+ else if (dist_kind == cutlass::Distribution::Sequential) {
157
+
158
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
159
+ }
160
+ else {
161
+ }
162
+ }
163
+
164
+ void initialize(
165
+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
166
+
167
+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
168
+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
169
+ tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
170
+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
171
+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
172
+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
173
+
174
+ initialize_tensor(tensor_A.host_view(), init_A, seed);
175
+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
176
+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
177
+
178
+ cutlass::reorder_convK<InterleavedK>(
179
+ tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size));
180
+
181
+ tensor_A.sync_device();
182
+ tensor_B.sync_device();
183
+ tensor_B_reordered.sync_device();
184
+ tensor_C.sync_device();
185
+ tensor_D_computed.sync_device();
186
+ tensor_D_reference.sync_device();
187
+ }
188
+
189
+ bool sufficient() const {
190
+ //
191
+ // Determine SMEM requirements and waive if not satisfied
192
+ //
193
+
194
+ size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);
195
+
196
+ cudaDeviceProp properties;
197
+ int device_idx;
198
+ cudaError_t result = cudaGetDevice(&device_idx);
199
+
200
+ if (result != cudaSuccess) {
201
+ throw std::runtime_error("cudaGetDevice() API call failed.");
202
+ }
203
+
204
+ result = cudaGetDeviceProperties(&properties, device_idx);
205
+
206
+ if (result != cudaSuccess) {
207
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
208
+ }
209
+
210
+ if (properties.sharedMemPerMultiprocessor < smem_size) {
211
+ return false;
212
+ }
213
+
214
+ return true;
215
+ }
216
+
217
+ /// Executes one test
218
+ bool run(
219
+ cutlass::conv::Conv2dProblemSize const &problem_size,
220
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
221
+ ElementCompute alpha = ElementCompute(1),
222
+ ElementCompute beta = ElementCompute(0)) {
223
+
224
+ // Waive test if insufficient CUDA device
225
+ if (!sufficient()) {
226
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
227
+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
228
+ }
229
+ return true;
230
+ }
231
+
232
+ #if 0 //display conv2d problem size for debugging
233
+ std::cout << problem_size << std::endl
234
+ << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl
235
+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
236
+ << std::endl;
237
+ #endif
238
+
239
+ initialize(problem_size);
240
+
241
+ // configure the operator
242
+ Conv2d conv2d_op;
243
+
244
+ typename Conv2d::Arguments conv2d_args(
245
+ problem_size,
246
+ tensor_A.device_ref(),
247
+ tensor_B_reordered.device_ref(),
248
+ tensor_C.device_ref(),
249
+ tensor_D_computed.device_ref(),
250
+ {alpha, beta},
251
+ split_k_mode
252
+ );
253
+
254
+ // find workspace requirement for parallel split-k reduction
255
+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
256
+
257
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
258
+
259
+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());
260
+
261
+ // conv2d operation with parallel split-k-mode
262
+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
263
+
264
+ // conv2d output is written to workspace in global memory
265
+ conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
266
+ // accumulate mma for each cta in k-dimension (1.0 * A * B)
267
+ conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
268
+ // update conv2d operator arguments
269
+ status = conv2d_op.update(conv2d_args, workspace.get());
270
+ }
271
+
272
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
273
+ if (status != cutlass::Status::kSuccess) {
274
+ return false;
275
+ }
276
+
277
+ // run conv2d operator
278
+ status = conv2d_op();
279
+
280
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
281
+ if (status != cutlass::Status::kSuccess) {
282
+ return false;
283
+ }
284
+
285
+ if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {
286
+
287
+ // configure parallel reduction operator
288
+ ReductionDevice reduction_op;
289
+
290
+ typename ReductionDevice::Arguments reduction_args(
291
+ cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
292
+ problem_size.split_k_slices,
293
+ cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
294
+ {
295
+ reinterpret_cast<ElementAccumulator*> (workspace.get()),
296
+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
297
+ },
298
+ {
299
+ tensor_D_computed.device_data(),
300
+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
301
+ },
302
+ {
303
+ tensor_C.device_data(),
304
+ ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
305
+ },
306
+ // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
307
+ {alpha, beta}
308
+ );
309
+
310
+ status = reduction_op.initialize(reduction_args, nullptr);
311
+
312
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
313
+ if (status != cutlass::Status::kSuccess) {
314
+ return false;
315
+ }
316
+
317
+ // run prallel reduction kernel
318
+ status = reduction_op();
319
+
320
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
321
+ if (status != cutlass::Status::kSuccess) {
322
+ return false;
323
+ }
324
+ }
325
+ bool passed = false;
326
+
327
+ tensor_D_computed.sync_host();
328
+
329
+ //
330
+ // Reference check - support caching results
331
+ //
332
+
333
+ CachedTestKey cached_test_key = CreateCachedConv2dTestKey<
334
+ ElementA, LayoutA,
335
+ ElementB, LayoutB,
336
+ ElementC, LayoutC,
337
+ ElementAccumulator,
338
+ ElementCompute
339
+ >(
340
+ kConvolutionalOperator,
341
+ problem_size,
342
+ alpha,
343
+ beta,
344
+ tensor_A.host_view(),
345
+ tensor_B.host_view(),
346
+ tensor_C.host_view()
347
+ );
348
+
349
+ //
350
+ // Look for the cached key
351
+ //
352
+
353
+ bool cached_result_loaded = false;
354
+ CachedTestResult cached_test_result;
355
+
356
+ std::string conv2d_result_cache_name =
357
+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
358
+
359
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
360
+
361
+ CachedTestResultListing cached_results(conv2d_result_cache_name);
362
+
363
+ auto cached = cached_results.find(cached_test_key);
364
+
365
+ cached_result_loaded = cached.first;
366
+ if (cached_result_loaded) {
367
+ cached_test_result = cached.second;
368
+ }
369
+ }
370
+
371
+ if (!cached_result_loaded) {
372
+
373
+ #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
374
+
375
+ cutlass::reference::device::Conv2d<
376
+ ElementA,
377
+ LayoutA,
378
+ ElementB,
379
+ LayoutB,
380
+ ElementC,
381
+ LayoutC,
382
+ ElementCompute,
383
+ ElementAccumulator,
384
+ cutlass::NumericConverterClamp<ElementC, ElementCompute>
385
+ >(
386
+ kConvolutionalOperator,
387
+ problem_size,
388
+ tensor_A.device_ref(),
389
+ tensor_B.device_ref(),
390
+ tensor_C.device_ref(),
391
+ tensor_D_reference.device_ref(),
392
+ alpha,
393
+ beta);
394
+
395
+ cudaError_t result = cudaDeviceSynchronize();
396
+ EXPECT_EQ(result, cudaSuccess) << " device reference error: "
397
+ << cudaGetErrorString(result);
398
+
399
+ // sync host (copy device data to host) for dumping error output in case of mismatches
400
+ tensor_D_reference.sync_host();
401
+
402
+ #else
403
+
404
+ cutlass::reference::host::Conv2d<
405
+ ElementA,
406
+ LayoutA,
407
+ ElementB,
408
+ LayoutB,
409
+ ElementC,
410
+ LayoutC,
411
+ ElementCompute,
412
+ ElementAccumulator,
413
+ ElementC,
414
+ cutlass::NumericConverterClamp<ElementC, ElementCompute>
415
+ >(
416
+ kConvolutionalOperator,
417
+ problem_size,
418
+ tensor_A.host_ref(),
419
+ tensor_B.host_ref(),
420
+ tensor_C.host_ref(),
421
+ tensor_D_reference.host_ref(),
422
+ alpha,
423
+ beta);
424
+
425
+ #endif
426
+
427
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
428
+
429
+ cached_test_result.D = TensorHash(tensor_D_reference.host_view());
430
+
431
+ CachedTestResultListing cached_results(conv2d_result_cache_name);
432
+
433
+ cached_results.append(cached_test_key, cached_test_result);
434
+ cached_results.write(conv2d_result_cache_name);
435
+ }
436
+ } // if (!cached_result_loaded)
437
+
438
+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
439
+
440
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
441
+ passed = (tensor_D_hash == cached_test_result.D);
442
+
443
+ EXPECT_EQ(tensor_D_hash, cached_test_result.D)
444
+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
445
+ }
446
+ else {
447
+
448
+ passed = cutlass::reference::host::TensorEquals(
449
+ tensor_D_computed.host_view(),
450
+ tensor_D_reference.host_view());
451
+ }
452
+
453
+ EXPECT_TRUE(passed);
454
+
455
+ if (!passed) {
456
+ std::stringstream fname;
457
+
458
+ fname << "error_Conv2d_ImplicitGemm_device_"
459
+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
460
+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
461
+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
462
+ << "ncxhwx_"
463
+ << problem_size.N << "x"
464
+ << problem_size.H << "x"
465
+ << problem_size.W << "x"
466
+ << problem_size.C
467
+ << "_cxrskx_"
468
+ << problem_size.K << "x"
469
+ << problem_size.R << "x"
470
+ << problem_size.S << "x"
471
+ << problem_size.C
472
+ << "_padding_"
473
+ << problem_size.pad_h << "x"
474
+ << problem_size.pad_w
475
+ << "_stride_"
476
+ << problem_size.stride_h << "x"
477
+ << problem_size.stride_w
478
+ << "_dilation_"
479
+ << problem_size.dilation_h << "x"
480
+ << problem_size.dilation_w << "_"
481
+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
482
+ << Conv2d::ThreadblockShape::kM << "x"
483
+ << Conv2d::ThreadblockShape::kN << "x"
484
+ << Conv2d::ThreadblockShape::kK << "_"
485
+ << Conv2d::WarpShape::kM << "x"
486
+ << Conv2d::WarpShape::kN << "x"
487
+ << Conv2d::WarpShape::kK << ".txt";
488
+
489
+ std::cout << fname.str() << std::endl;
490
+
491
+ std::ofstream results(fname.str());
492
+
493
+ results << problem_size << std::endl;
494
+
495
+ results
496
+ << "\nA:\n" << tensor_A.host_view() << "\n"
497
+ << "\nB:\n" << tensor_B.host_view() << "\n"
498
+ << "\nC:\n" << tensor_C.host_view() << "\n";
499
+
500
+ results << "\nD reference (hash: " << cached_test_result.D << ")\n";
501
+
502
+ if (!cached_result_loaded) {
503
+ results
504
+ << tensor_D_reference.host_view() << "\n";
505
+ }
506
+
507
+ results
508
+ << "\nD computed (hash: " << tensor_D_hash << ")\n"
509
+ << tensor_D_computed.host_view() << "\n";
510
+
511
+ }
512
+
513
+ return passed;
514
+ }
515
+
516
+ };
517
+
518
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////
519
+ // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
520
+ // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
521
+ // Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
522
+ // (conv_blacklist_sizes)
523
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
524
+ template <typename ImplicitGemm, int InterleavedK>
525
+ bool TestAllInterleavedConv2d(
526
+ const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
527
+ const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {
528
+
529
+ bool passed = true;
530
+
531
+ //
532
+ // Testbed object
533
+ //
534
+
535
+ InterleavedTestbedConv2d<ImplicitGemm, InterleavedK> testbed;
536
+
537
+ //
538
+ // Get conv problem sizes to run conv operator
539
+ //
540
+ TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout
541
+
542
+ // Vector of conv2d problem sizes to avoid duplicate runs
543
+ Conv2dProblemVector conv_tested_sizes;
544
+
545
+ Conv2dProblemVector const *problem_vectors[] = {
546
+ &conv_test_sizes, // run user specified sizes
547
+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
548
+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
549
+ #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
550
+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
551
+ #endif
552
+ };
553
+
554
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
555
+ for (Conv2dProblemVector const * problem_vector : problem_vectors) {
556
+
557
+ ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK
558
+ auto pruned_problem_vector = prune(*problem_vector, channel_spec);
559
+
560
+ // Run conv testbed on default convolution sizes
561
+ for(auto conv_problem : pruned_problem_vector) {
562
+
563
+ // Skip blacklist and avoid duplicate problem sizes
564
+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
565
+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
566
+ continue;
567
+ }
568
+
569
+ //
570
+ // Procedurally disable certain cases
571
+ //
572
+
573
+ // CUTLASS DGRAD's unity stride specialization only support stride {1, 1}
574
+ if ((ImplicitGemm::kConvolutionalOperator ==
575
+ cutlass::conv::Operator::kDgrad) &&
576
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
577
+ cutlass::conv::StrideSupport::kUnity)) {
578
+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
579
+ continue;
580
+ }
581
+ }
582
+
583
+ //
584
+ // Test
585
+ //
586
+ // push back tested problem size to avoid re-running duplicates
587
+ conv_tested_sizes.push_back(conv_problem);
588
+
589
+ // test mode = xcross
590
+ passed = testbed.run(
591
+ conv_problem,
592
+ cutlass::conv::SplitKMode::kSerial);
593
+
594
+ if (!passed) {
595
+ return false;
596
+ }
597
+
598
+ // test mode = convolution
599
+ passed = testbed.run(
600
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
601
+ cutlass::conv::SplitKMode::kSerial);
602
+
603
+ if (!passed) {
604
+ return false;
605
+ }
606
+ }
607
+ }
608
+
609
+ #if 0
610
+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
611
+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
612
+ // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
613
+ // alpha and beta for local testing, but only runs one value for alpha and beta.
614
+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
615
+ {1, 17, 11, 288}, // input size (NHWC)
616
+ {160, 3, 3, 288}, // filter size (KRSC)
617
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
618
+ {1, 1}, // stride (stride_h, stride_w)
619
+ {1, 1} // dilation (dilation_h, dilation_w)
620
+ );
621
+
622
+ cutlass::conv::SplitKMode split_k_modes [] = {
623
+ cutlass::conv::SplitKMode::kSerial,
624
+ cutlass::conv::SplitKMode::kParallel,
625
+ };
626
+
627
+ int split_k_slices[] = {
628
+ 1, 2, 3, 4, 201
629
+ };
630
+
631
+ double problem_alpha[] = {
632
+ 2.0
633
+ };
634
+
635
+ double problem_beta[] = {
636
+ 2.0
637
+ };
638
+
639
+ for (auto split_k_mode : split_k_modes) {
640
+ for (auto split_k_slice : split_k_slices) {
641
+ for (auto alpha : problem_alpha) {
642
+ for (auto beta : problem_beta) {
643
+
644
+ passed = testbed.run(
645
+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
646
+ split_k_mode,
647
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
648
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
649
+
650
+ if (!passed) {
651
+ return false;
652
+ }
653
+ }
654
+ }
655
+ }
656
+ }
657
+ #endif
658
+
659
+ return passed;
660
+ }
661
+
662
+ /////////////////////////////////////////////////////////////////////////////////////////////////
663
+
664
+ } // namespace device
665
+ } // namespace conv
666
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ /*! \file
33
+ \brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling
34
+ */
35
+
36
+ #pragma once
37
+
38
+ #include <iostream>
39
+ #include <fstream>
40
+ #include <sstream>
41
+
42
+ #include "conv2d_problems.h"
43
+ #include "../../common/cutlass_unit_test.h"
44
+ #include "../../gemm/device/testbed_utils.h"
45
+
46
+ #include "cutlass/matrix_coord.h"
47
+ #include "cutlass/conv/convolution.h"
48
+ #include "cutlass/layout/matrix.h"
49
+
50
+ #include "cutlass/util/host_tensor.h"
51
+ #include "cutlass/util/tensor_view_io.h"
52
+ #include "cutlass/util/distribution.h"
53
+ #include "cutlass/util/reference/host/convolution.h"
54
+ #include "cutlass/util/reference/host/tensor_copy.h"
55
+ #include "cutlass/util/reference/host/tensor_compare.h"
56
+ #include "cutlass/util/reference/host/tensor_fill.h"
57
+ #include "cutlass/util/reference/host/tensor_reduce.h"
58
+
59
+ namespace test {
60
+ namespace conv {
61
+ namespace device {
62
+
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ template <
66
+ typename Conv,
67
+ template<typename T> class ActivationFunctor
68
+ >
69
+ struct TestbedConv2dWithAbsMax {
70
+
71
+ using ElementAccumulator = typename Conv::ElementAccumulator;
72
+ using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute;
73
+ using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor;
74
+ using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax;
75
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator;
76
+
77
+ static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded;
78
+ static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded;
79
+ bool doScaleA;
80
+ bool doScaleB;
81
+ bool doScaleC;
82
+
83
+ /// Initialization
84
+ cutlass::Distribution::Kind init_A;
85
+ cutlass::Distribution::Kind init_B;
86
+ cutlass::Distribution::Kind init_C;
87
+ uint64_t seed;
88
+
89
+ cutlass::HostTensor<typename Conv::ElementA, typename Conv::LayoutA> tensor_A;
90
+ cutlass::HostTensor<typename Conv::ElementB, typename Conv::LayoutB> tensor_B;
91
+ cutlass::HostTensor<typename Conv::ElementC, typename Conv::LayoutC> tensor_C;
92
+ cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementAuxOutput, typename Conv::LayoutC> tensor_Aux;
93
+ cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementOutput, typename Conv::LayoutC> tensor_D;
94
+ cutlass::HostTensor<typename Conv::ElementC, typename Conv::LayoutC> tensor_Vector;
95
+ cutlass::HostTensor<ElementAccumulator, typename Conv::LayoutC> tmp_D;
96
+ cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementOutput, typename Conv::LayoutC> reference_D;
97
+ cutlass::HostTensor<typename Conv::EpilogueOutputOp::ElementAuxOutput, typename Conv::LayoutC> reference_Aux;
98
+ cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_A;
99
+ cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_B;
100
+ cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_C;
101
+ cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_D;
102
+ cutlass::HostTensor<ElementScalingFactor, typename Conv::LayoutC> scale_Aux;
103
+ cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> abs_max_Aux;
104
+ cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> abs_max_D;
105
+ cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> reference_abs_max_Aux;
106
+ cutlass::HostTensor<ElementAbsmax, typename Conv::LayoutC> reference_abs_max_D;
107
+
108
+ //
109
+ // Methods
110
+ //
111
+
112
+ TestbedConv2dWithAbsMax(
113
+ bool scaleA = true,
114
+ bool scaleB = true,
115
+ bool scaleC = true,
116
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
117
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
118
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
119
+ uint64_t seed_ = 2080
120
+ ):
121
+ doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC),
122
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { }
123
+
124
+ /// Helper to initialize scaling factors
125
+ template <typename Element, typename Layout>
126
+ bool initialize_scale_factor(cutlass::TensorView<Element, Layout> view, uint64_t seed, int bits=0) {
127
+ cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits);
128
+ return true;
129
+ }
130
+
131
+ /// Helper to initialize a tensor view
132
+ template <typename Element, typename Layout>
133
+ bool initialize_tensor(
134
+ cutlass::TensorView<Element, Layout> view,
135
+ cutlass::Distribution::Kind dist_kind,
136
+ uint64_t seed) {
137
+
138
+ if (dist_kind == cutlass::Distribution::Uniform) {
139
+
140
+ double scope_max, scope_min;
141
+ int bits_input = cutlass::sizeof_bits<Element>::value;
142
+ int bits_output = cutlass::sizeof_bits<typename Conv::ElementC>::value;
143
+
144
+ if (bits_input == 1) {
145
+ scope_max = 2;
146
+ scope_min = 0;
147
+ } else if (bits_input <= 8) {
148
+ scope_max = 2;
149
+ scope_min = -2;
150
+ } else if (bits_output == 16) {
151
+ scope_max = 5;
152
+ scope_min = -5;
153
+ } else {
154
+ scope_max = 8;
155
+ scope_min = -8;
156
+ }
157
+
158
+ cutlass::reference::host::TensorFillRandomUniform(
159
+ view, seed, scope_max, scope_min, 0);
160
+ }
161
+ else if (dist_kind == cutlass::Distribution::Identity) {
162
+
163
+ cutlass::reference::host::TensorFillIdentity(view);
164
+ }
165
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
166
+
167
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
168
+ }
169
+ else if (dist_kind == cutlass::Distribution::Sequential) {
170
+
171
+ cutlass::reference::host::BlockFillSequential(
172
+ view.data(), view.capacity());
173
+ }
174
+ else {
175
+ EXPECT_TRUE(false) << "Not implemented";
176
+ return false;
177
+ }
178
+
179
+ return true;
180
+ }
181
+
182
+ /// Initializes data structures
183
+ void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) {
184
+ //
185
+ // Allocate the GEMM workspace
186
+ //
187
+
188
+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
189
+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
190
+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
191
+ tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
192
+ tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()});
193
+ reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false);
194
+ tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false);
195
+
196
+ EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019));
197
+ EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018));
198
+ EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017));
199
+ EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020));
200
+
201
+ // It is possible to randomly initialize to all zeros, so override this with non-zeros
202
+ // in the upper left corner of each operand.
203
+ cutlass::Coord<4> origin(0);
204
+ tensor_A.host_view().at(origin) = typename Conv::ElementA(1);
205
+ tensor_B.host_view().at(origin) = typename Conv::ElementB(1);
206
+ tensor_C.host_view().at(origin) = typename Conv::ElementC(1);
207
+ tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1);
208
+
209
+ cutlass::reference::host::TensorFill(tensor_D.host_view());
210
+ cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view());
211
+
212
+ tensor_A.sync_device();
213
+ tensor_B.sync_device();
214
+ tensor_C.sync_device();
215
+ tensor_D.sync_device();
216
+ tensor_Vector.sync_device();
217
+
218
+ int scale_bits = 2;
219
+ if (doScaleA) {
220
+ scale_A.resize({1, 1, 1, 1});
221
+ EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits));
222
+ scale_A.sync_device();
223
+ }
224
+
225
+ if (doScaleB) {
226
+ scale_B.resize({1, 1, 1, 1});
227
+ EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits));
228
+ scale_B.sync_device();
229
+ }
230
+
231
+ if (doScaleC) {
232
+ scale_C.resize({1, 1, 1, 1});
233
+ EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits));
234
+ scale_C.sync_device();
235
+ }
236
+
237
+ if (kScaleOutput) {
238
+ scale_D.resize({1, 1, 1, 1});
239
+ EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits));
240
+ scale_D.sync_device();
241
+
242
+ abs_max_D.resize({1, 1, 1, 1});
243
+ cutlass::reference::host::TensorFill(abs_max_D.host_view());
244
+ abs_max_D.sync_device();
245
+
246
+ reference_abs_max_D.resize({1, 1, 1, 1});
247
+ }
248
+
249
+ if (kScaleAux) {
250
+ tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
251
+ cutlass::reference::host::TensorFill(tensor_Aux.host_view());
252
+ tensor_Aux.sync_device();
253
+
254
+ scale_Aux.resize({1, 1, 1, 1});
255
+ EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits));
256
+ scale_Aux.sync_device();
257
+
258
+ abs_max_Aux.resize({1, 1, 1, 1});
259
+ cutlass::reference::host::TensorFill(abs_max_Aux.host_view());
260
+ abs_max_Aux.sync_device();
261
+
262
+ reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false);
263
+ reference_abs_max_Aux.resize({1, 1, 1, 1});
264
+ }
265
+ }
266
+
267
+ /// Compares computed reference with device reference and outputs to a file if incorrect
268
+ bool compare_reference(
269
+ cutlass::conv::Conv2dProblemSize const &problem_size,
270
+ ElementCompute alpha,
271
+ ElementCompute beta) {
272
+
273
+ tensor_D.sync_host();
274
+
275
+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0);
276
+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0);
277
+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0);
278
+
279
+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0);
280
+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0);
281
+ bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view());
282
+
283
+ if (kScaleAux) {
284
+ tensor_Aux.sync_host();
285
+ abs_max_Aux.sync_host();
286
+ EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0);
287
+ EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0);
288
+ EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0);
289
+ passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view());
290
+ passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view());
291
+ }
292
+
293
+ if (kScaleOutput) {
294
+ abs_max_D.sync_host();
295
+ EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0);
296
+ passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view());
297
+ }
298
+
299
+ EXPECT_TRUE(passed) << " mismatched reference";
300
+
301
+ if (!passed) {
302
+
303
+ std::ofstream file0("conv_testbed_with_amax_errors_reference.txt");
304
+ std::ofstream file1("conv_testbed_with_amax_errors_computed.txt");
305
+
306
+ std::ofstream file("conv_testbed_with_amax_errors.txt");
307
+
308
+ file
309
+ << "problem: " << problem_size
310
+ << ", alpha: " << alpha << ", beta: " << beta << "\n\n";
311
+
312
+ file
313
+ << "A =\n" << tensor_A.host_view()
314
+ << "\nB =\n" << tensor_B.host_view()
315
+ << "\nC =\n" << tensor_C.host_view()
316
+ << "\nVector =\n" << tensor_Vector.host_view()
317
+ << "\nScaleA = " << scale_A.host_view()
318
+ << "\nScaleB = " << scale_B.host_view()
319
+ << "\nScaleC = " << scale_C.host_view()
320
+ << "\nScaleD = " << scale_D.host_view()
321
+ << "\nScaleAux = " << scale_Aux.host_view()
322
+ << std::endl;
323
+
324
+ file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl;
325
+ file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl;
326
+ if (kScaleAux) {
327
+ file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl;
328
+ file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl;
329
+ file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl;
330
+ file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl;
331
+ }
332
+ if (kScaleOutput) {
333
+ file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl;
334
+ file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl;
335
+ }
336
+ }
337
+
338
+ return passed;
339
+ }
340
+
341
+ /// Verifies the result is a GEMM
342
+ bool verify(
343
+ cutlass::conv::Conv2dProblemSize const &problem_size,
344
+ ElementCompute alpha,
345
+ ElementCompute beta) {
346
+
347
+ cutlass::Coord<4> origin(0);
348
+ ElementCompute scaled_alpha = alpha;
349
+ if (doScaleA) {
350
+ scaled_alpha *= scale_A.host_view().at(origin);
351
+ }
352
+ if (doScaleB) {
353
+ scaled_alpha *= scale_B.host_view().at(origin);
354
+ }
355
+
356
+ ElementCompute scaled_beta = beta;
357
+ if (doScaleC) {
358
+ scaled_beta *= scale_C.host_view().at(origin);
359
+ }
360
+
361
+ //
362
+ // Verify
363
+ //
364
+
365
+ cutlass::reference::host::Conv2d<
366
+ typename Conv::ElementA, typename Conv::LayoutA,
367
+ typename Conv::ElementB, typename Conv::LayoutB,
368
+ typename Conv::ElementC, typename Conv::LayoutC,
369
+ ElementCompute, ElementAccumulator, ElementAccumulator
370
+ >(
371
+ kConvolutionalOperator,
372
+ problem_size,
373
+ tensor_A.host_ref(),
374
+ tensor_B.host_ref(),
375
+ tensor_C.host_ref(),
376
+ tmp_D.host_ref(),
377
+ scaled_alpha,
378
+ scaled_beta
379
+ );
380
+
381
+ ElementCompute tmp_abs_max_Aux(0.);
382
+ ElementCompute tmp_abs_max_D(0.);
383
+
384
+ cutlass::NumericConverter<ElementCompute, typename Conv::ElementC> cvt_c_to_compute;
385
+ cutlass::NumericConverter<ElementCompute, ElementAccumulator> cvt_accum_to_compute;
386
+ cutlass::NumericConverter<ElementAbsmax, ElementCompute> cvt_compute_to_absmax;
387
+ cutlass::NumericConverter<typename Conv::EpilogueOutputOp::ElementOutput, ElementCompute> cvt_compute_to_d;
388
+ cutlass::NumericConverter<typename Conv::EpilogueOutputOp::ElementAuxOutput, ElementCompute> cvt_compute_to_aux;
389
+
390
+ cutlass::absolute_value_op<ElementCompute> abs;
391
+ cutlass::maximum_with_nan_propogation<ElementCompute> max;
392
+ ActivationFunctor<ElementCompute> act;
393
+
394
+ ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.);
395
+
396
+ for (int n = 0; n < problem_size.N; ++n) {
397
+ for (int p = 0; p < problem_size.P; ++p) {
398
+ for (int q = 0; q < problem_size.Q; ++q) {
399
+ for (int k = 0; k < problem_size.K; ++k) {
400
+ ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k}));
401
+ ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k}));
402
+ ElementCompute aux = intermediate + bias;
403
+ ElementCompute d = act(aux);
404
+ tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux);
405
+ tmp_abs_max_D = max(abs(d), tmp_abs_max_D);
406
+ reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale);
407
+
408
+ if (kScaleAux) {
409
+ reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin));
410
+ }
411
+ }
412
+ }
413
+ }
414
+ }
415
+ if (kScaleAux) {
416
+ reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux);
417
+ }
418
+
419
+ if (kScaleOutput) {
420
+ reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D);
421
+ }
422
+
423
+ return compare_reference(problem_size, alpha, beta);
424
+ }
425
+
426
+ /// Returns true if the CUDA device is sufficient to execute the kernel.
427
+ bool sufficient() const {
428
+ //
429
+ // Determine SMEM requirements and waive if not satisfied
430
+ //
431
+
432
+ size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage);
433
+
434
+ cudaDeviceProp properties;
435
+ int device_idx;
436
+ cudaError_t result = cudaGetDevice(&device_idx);
437
+
438
+ if (result != cudaSuccess) {
439
+ throw std::runtime_error("cudaGetDevice() API call failed.");
440
+ }
441
+
442
+ result = cudaGetDeviceProperties(&properties, device_idx);
443
+
444
+ if (result != cudaSuccess) {
445
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
446
+ }
447
+
448
+ if (properties.sharedMemPerBlockOptin < smem_size) {
449
+ return false;
450
+ }
451
+
452
+ return true;
453
+ }
454
+
455
+ /// Executes one test
456
+ bool run(
457
+ cutlass::conv::Conv2dProblemSize const &problem_size,
458
+ ElementCompute alpha = ElementCompute(1),
459
+ ElementCompute beta = ElementCompute(0))
460
+ {
461
+
462
+ // Waive test if insufficient CUDA device
463
+ if (!sufficient()) {
464
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
465
+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
466
+ }
467
+ return true;
468
+ }
469
+
470
+ this->initialize(problem_size);
471
+
472
+ //
473
+ // Initialize the GEMM operator
474
+ //
475
+
476
+ typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta};
477
+ typename Conv::EpilogueOutputOp::Params epilogue_params{
478
+ activation_params,
479
+ scale_A.device_data(),
480
+ scale_B.device_data(),
481
+ scale_C.device_data(),
482
+ scale_D.device_data(),
483
+ scale_Aux.device_data(),
484
+ abs_max_Aux.device_data(),
485
+ abs_max_D.device_data()
486
+ };
487
+
488
+ typename Conv::Arguments arguments{
489
+ problem_size,
490
+ tensor_A.device_ref(),
491
+ tensor_B.device_ref(),
492
+ tensor_C.device_ref(),
493
+ tensor_D.device_ref(),
494
+ tensor_Aux.device_ref(),
495
+ epilogue_params,
496
+ cutlass::conv::SplitKMode::kSerial,
497
+ tensor_Vector.device_data(),
498
+ 0
499
+ };
500
+
501
+ Conv conv2d_op;
502
+
503
+ cutlass::Status status = conv2d_op.can_implement(arguments);
504
+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
505
+
506
+ size_t workspace_size = Conv::get_workspace_size(arguments);
507
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
508
+
509
+ status = conv2d_op.initialize(arguments, workspace.get());
510
+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
511
+
512
+ //
513
+ // Run the GEMM
514
+ //
515
+
516
+ status = conv2d_op();
517
+
518
+ EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status);
519
+
520
+ cudaError_t cuda_error = cudaDeviceSynchronize();
521
+ EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error);
522
+
523
+ //
524
+ // Verify
525
+ //
526
+
527
+ bool passed = this->verify(problem_size, alpha, beta);
528
+
529
+ if (!passed) {
530
+ std::cout << "Failed" << std::endl;
531
+ }
532
+
533
+ return passed;
534
+ }
535
+
536
+ };
537
+
538
+ /////////////////////////////////////////////////////////////////////////////////////////////////
539
+
540
+ template <
541
+ typename ImplicitGemm,
542
+ template<typename T> class ActivationFunctor = cutlass::epilogue::thread::Identity
543
+ >
544
+ bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) {
545
+ const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector();
546
+ const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector();
547
+
548
+ //
549
+ // Testbed object
550
+ //
551
+
552
+ TestbedConv2dWithAbsMax<ImplicitGemm, ActivationFunctor> testbed(scaleA, scaleB, scaleC);
553
+
554
+ //
555
+ // Get conv problem sizes to run conv operator
556
+ //
557
+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
558
+
559
+ // Vector of conv2d problem sizes to avoid duplicate runs
560
+ Conv2dProblemVector conv_tested_sizes;
561
+
562
+ Conv2dProblemVector const *problem_vectors[] = {
563
+ &conv_test_sizes, // run user specified sizes
564
+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
565
+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
566
+ #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
567
+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
568
+ #endif
569
+ };
570
+
571
+ bool passed = true;
572
+
573
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
574
+ for (Conv2dProblemVector const * problem_vector : problem_vectors) {
575
+
576
+ // Prune all problems with channels that aren't divisible by the number of elements accessed per
577
+ // load for operands A and B. This is meant to align with the requirements of iterators used for
578
+ // fprop kernels.
579
+ ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
580
+ auto pruned_problem_vector = prune(*problem_vector, channel_spec);
581
+
582
+ // Run conv testbed on default convolution sizes
583
+ for(auto conv_problem : pruned_problem_vector) {
584
+
585
+ // Skip blacklist and avoid duplicate problem sizes
586
+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
587
+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
588
+ continue;
589
+ }
590
+
591
+ //
592
+ // Test
593
+ //
594
+ // push back tested problem size to avoid re-running duplicates
595
+ conv_tested_sizes.push_back(conv_problem);
596
+
597
+ // test mode = xcross
598
+ passed &= testbed.run(conv_problem);
599
+
600
+ if (!passed) {
601
+ return false;
602
+ }
603
+
604
+ // test mode = convolution
605
+ passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution));
606
+
607
+ if (!passed) {
608
+ return false;
609
+ }
610
+ }
611
+ }
612
+
613
+ return passed;
614
+ }
615
+
616
+ /////////////////////////////////////////////////////////////////////////////////////////////////
617
+
618
+ } // namespace device
619
+ } // namespace conv
620
+ } // namespace test
621
+
622
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h ADDED
@@ -0,0 +1,734 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM for fused epilogue broadcast testbed
33
+
34
+ Parallel split-k is not tested because we can just use regular conv kernel
35
+ when we need to use parallel-splitk. Broadcast can happen in the reduction
36
+ kernel.
37
+ */
38
+ #pragma once
39
+
40
+ #include <fstream>
41
+
42
+ #include "../../common/cutlass_unit_test.h"
43
+ #include "cutlass/cutlass.h"
44
+
45
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
46
+ #include "cutlass/reduction/device/reduce_split_k.h"
47
+ #include "cutlass/reduction/thread/reduction_operators.h"
48
+
49
+ #include "conv2d_problems.h"
50
+
51
+ #include "cutlass/util/host_tensor.h"
52
+ #include "cutlass/util/reference/host/tensor_fill.h"
53
+ #include "cutlass/util/reference/device/tensor_compare.h"
54
+ #include "cutlass/util/reference/host/tensor_compare.h"
55
+
56
+ #include "cutlass/util/reference/host/convolution.h"
57
+ #include "cutlass/util/reference/device/convolution.h"
58
+
59
+ #include "cutlass/core_io.h"
60
+ #include "cutlass/util/tensor_view_io.h"
61
+
62
+ #include "../cache_testbed_output.h"
63
+
64
+ namespace test {
65
+ namespace conv {
66
+ namespace device {
67
+
68
+ /////////////////////////////////////////////////////////////////////////////////////////////////
69
+
70
+ template <typename Conv2d>
71
+ struct Conv2dWithBroadcastReferenceOp {
72
+
73
+ using OutputOp = typename Conv2d::EpilogueOutputOp;
74
+
75
+ using ElementCompute = typename OutputOp::ElementCompute;
76
+ using ElementZ = typename OutputOp::ElementZ;
77
+ using ElementT = typename OutputOp::ElementT;
78
+
79
+ typename OutputOp::BinaryOp binary_op;
80
+ typename OutputOp::ElementwiseOp elementwise_op;
81
+
82
+ Conv2dWithBroadcastReferenceOp() { }
83
+
84
+ void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) {
85
+ ElementCompute t_full = binary_op(conv2d, bias);
86
+ T = ElementT(t_full);
87
+
88
+ ElementCompute z_full = elementwise_op(t_full);
89
+ Z = ElementZ(z_full);
90
+ }
91
+ };
92
+
93
+ /////////////////////////////////////////////////////////////////////////////////////////////////
94
+
95
+ // Fused testbed
96
+ //
97
+ // Y = CONV(AB, C)
98
+ //
99
+ // T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k])
100
+ //
101
+ // Z[n, p, q, k] = Elementwise(T[n, p, q, k])
102
+ //
103
+
104
+ template <
105
+ typename Conv2d,
106
+ typename ReferenceOp,
107
+ bool AddBroadcastFirst = false
108
+ >
109
+ class TestbedConv2dWithBroadcast {
110
+ public:
111
+
112
+ using ElementA = typename Conv2d::ElementA;
113
+ using LayoutA = typename Conv2d::LayoutA;
114
+ using ElementB = typename Conv2d::ElementB;
115
+ using LayoutB = typename Conv2d::LayoutB;
116
+ using ElementC = typename Conv2d::ElementC;
117
+ using LayoutC = typename Conv2d::LayoutC;
118
+ using ElementAccumulator = typename Conv2d::ElementAccumulator;
119
+ using ElementCompute = typename Conv2d::ElementCompute;
120
+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
121
+ using ElementZ = typename EpilogueOutputOp::ElementZ;
122
+ using ElementT = typename EpilogueOutputOp::ElementT;
123
+ using ElementVector = typename EpilogueOutputOp::ElementVector;
124
+
125
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
126
+ static const bool kAddBroadcastFirst = AddBroadcastFirst;
127
+ static const bool kStoreT = EpilogueOutputOp::kStoreT;
128
+
129
+ public:
130
+
131
+ /// Initialization
132
+ cutlass::Distribution::Kind init_A;
133
+ cutlass::Distribution::Kind init_B;
134
+ cutlass::Distribution::Kind init_C;
135
+ uint64_t seed;
136
+
137
+ cutlass::HostTensor<ElementA, LayoutA> tensor_A;
138
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B;
139
+ cutlass::HostTensor<ElementC, LayoutC> tensor_C;
140
+ cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_C_reference;
141
+ cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_computed;
142
+ cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_reference;
143
+ cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
144
+ cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
145
+ cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
146
+ cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
147
+
148
+ public:
149
+
150
+ TestbedConv2dWithBroadcast(
151
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
152
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
153
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
154
+ uint64_t seed_ = 2080
155
+ ):
156
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
157
+
158
+ }
159
+
160
+ /// Helper to initialize a tensor view
161
+ template <typename Element, typename Layout>
162
+ void initialize_tensor(
163
+ cutlass::TensorView<Element, Layout> view,
164
+ cutlass::Distribution::Kind dist_kind,
165
+ uint64_t seed) {
166
+
167
+ if (dist_kind == cutlass::Distribution::Uniform) {
168
+
169
+ int scope;
170
+ int bits = cutlass::sizeof_bits<Element>::value;
171
+
172
+ if (bits <= 8) {
173
+ scope = 2;
174
+ }
175
+ else if (bits == 16) {
176
+ if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
177
+ scope = 3;
178
+ }
179
+ else {
180
+ scope = 5;
181
+ }
182
+ }
183
+ else {
184
+ scope = 8;
185
+ }
186
+
187
+ cutlass::reference::host::TensorFillRandomUniform(
188
+ view, seed, scope, -scope, 0);
189
+ }
190
+ else if (dist_kind == cutlass::Distribution::Identity) {
191
+
192
+ cutlass::reference::host::TensorFillIdentity(view);
193
+ }
194
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
195
+
196
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
197
+ }
198
+ else if (dist_kind == cutlass::Distribution::Sequential) {
199
+
200
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
201
+ }
202
+ else {
203
+ }
204
+ }
205
+
206
+ void initialize(
207
+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
208
+
209
+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
210
+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
211
+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
212
+ tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
213
+ tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
214
+ tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
215
+ tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
216
+ tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
217
+ tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
218
+ tensor_Broadcast.resize({
219
+ 1,
220
+ 1,
221
+ 1,
222
+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(),
223
+ });
224
+
225
+ initialize_tensor(tensor_A.host_view(), init_A, seed);
226
+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
227
+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
228
+ initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39);
229
+
230
+ for (int n = 0; n < tensor_C_reference.extent().n(); ++n) {
231
+ for (int p = 0; p < tensor_C_reference.extent().h(); ++p) {
232
+ for (int q = 0; q < tensor_C_reference.extent().w(); ++q) {
233
+ for (int k = 0; k < tensor_C_reference.extent().c(); ++k) {
234
+ tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k}));
235
+ }
236
+ }
237
+ }
238
+ }
239
+
240
+ tensor_A.sync_device();
241
+ tensor_B.sync_device();
242
+ tensor_C.sync_device();
243
+ tensor_Broadcast.sync_device();
244
+ tensor_C_reference.sync_device();
245
+ tensor_Z_computed.sync_device();
246
+ tensor_Z_reference.sync_device();
247
+ tensor_T_computed.sync_device();
248
+ tensor_T_reference.sync_device();
249
+ tensor_Y_reference.sync_device();
250
+ }
251
+
252
+ bool sufficient() const {
253
+ //
254
+ // Determine SMEM requirements and waive if not satisfied
255
+ //
256
+
257
+ size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);
258
+
259
+ cudaDeviceProp properties;
260
+ int device_idx;
261
+ cudaError_t result = cudaGetDevice(&device_idx);
262
+
263
+ if (result != cudaSuccess) {
264
+ throw std::runtime_error("cudaGetDevice() API call failed.");
265
+ }
266
+
267
+ result = cudaGetDeviceProperties(&properties, device_idx);
268
+
269
+ if (result != cudaSuccess) {
270
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
271
+ }
272
+
273
+ if (properties.sharedMemPerBlockOptin < smem_size) {
274
+ return false;
275
+ }
276
+
277
+ return true;
278
+ }
279
+
280
+ /// Executes one test
281
+ bool run(
282
+ cutlass::conv::Conv2dProblemSize const &problem_size,
283
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
284
+ ElementCompute alpha = ElementCompute(1),
285
+ ElementCompute beta = ElementCompute(1)) {
286
+
287
+ // Waive test if insufficient CUDA device
288
+ if (!sufficient()) {
289
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
290
+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
291
+ }
292
+ return true;
293
+ }
294
+
295
+ #if 0 //display conv2d problem size for debugging
296
+ std::cout << problem_size << std::endl
297
+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
298
+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
299
+ << std::endl;
300
+ #endif
301
+
302
+ initialize(problem_size);
303
+
304
+ // configure the operator
305
+ Conv2d conv2d_op;
306
+ typename Conv2d::Arguments conv2d_args(
307
+ problem_size,
308
+ tensor_A.device_ref(),
309
+ tensor_B.device_ref(),
310
+ tensor_C.device_ref(),
311
+ tensor_Z_computed.device_ref(),
312
+ {alpha, beta},
313
+ split_k_mode,
314
+ tensor_Broadcast.device_data(),
315
+ kStoreT ? tensor_T_computed.device_data() : nullptr,
316
+ 0, // This must be zero
317
+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
318
+ );
319
+
320
+ // initialize the kernel
321
+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
322
+
323
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
324
+
325
+ cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());
326
+
327
+ if (status != cutlass::Status::kSuccess) {
328
+ cudaError_t error = cudaGetLastError();
329
+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
330
+ return true;
331
+ }
332
+
333
+ // run conv2d operator
334
+ status = conv2d_op();
335
+
336
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
337
+ if (status != cutlass::Status::kSuccess) {
338
+ return false;
339
+ }
340
+
341
+ bool passed = false;
342
+
343
+ cudaError_t result = cudaDeviceSynchronize();
344
+ EXPECT_EQ(result, cudaSuccess) << " device reference error: "
345
+ << cudaGetErrorString(result);
346
+
347
+ tensor_T_computed.sync_host();
348
+ tensor_Z_computed.sync_host();
349
+
350
+ //
351
+ // Reference check
352
+ //
353
+
354
+ // When kAddBroadcastFirst is true, add bias on the host
355
+ ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta;
356
+
357
+ #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
358
+
359
+ cutlass::reference::device::Conv2d<
360
+ ElementA,
361
+ LayoutA,
362
+ ElementB,
363
+ LayoutB,
364
+ ElementAccumulator,
365
+ LayoutC,
366
+ ElementAccumulator,
367
+ ElementAccumulator
368
+ >(
369
+ kConvolutionalOperator,
370
+ problem_size,
371
+ tensor_A.device_ref(),
372
+ tensor_B.device_ref(),
373
+ tensor_C_reference.device_ref(),
374
+ tensor_Y_reference.device_ref(),
375
+ alpha,
376
+ beta_ref);
377
+
378
+ // sync host (copy device data to host) for dumping error output in case of mismatches
379
+ tensor_Y_reference.sync_host();
380
+
381
+ #else
382
+
383
+ cutlass::reference::host::Conv2d<
384
+ ElementA,
385
+ LayoutA,
386
+ ElementB,
387
+ LayoutB,
388
+ ElementAccumulator,
389
+ LayoutC,
390
+ ElementAccumulator,
391
+ ElementAccumulator
392
+ >(
393
+ kConvolutionalOperator,
394
+ problem_size,
395
+ tensor_A.host_ref(),
396
+ tensor_B.host_ref(),
397
+ tensor_C_reference.host_ref(),
398
+ tensor_Y_reference.host_ref(),
399
+ alpha,
400
+ beta_ref);
401
+
402
+ #endif
403
+ ReferenceOp reference_op;
404
+
405
+ // compute tensor Z and tensor T
406
+ for (int n = 0; n < problem_size.N; ++n) {
407
+ for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) {
408
+ for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) {
409
+ for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) {
410
+
411
+ ElementZ z{};
412
+ ElementT t{};
413
+
414
+ ElementCompute accum = tensor_Y_reference.at({n, p, q, k});
415
+ ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k}));
416
+
417
+
418
+ if (kAddBroadcastFirst) {
419
+ reference_op(z, t, accum + bias,
420
+ beta * ElementCompute(tensor_C_reference.at({n, p, q, k})));
421
+ } else {
422
+ reference_op(z, t, accum, bias);
423
+ }
424
+
425
+ tensor_Z_reference.at({n, p, q, k}) = z;
426
+ tensor_T_reference.at({n, p, q, k}) = t;
427
+ }
428
+ }
429
+ }
430
+ }
431
+
432
+ if (kStoreT) {
433
+ passed = cutlass::reference::host::TensorEquals(
434
+ tensor_T_computed.host_view(),
435
+ tensor_T_reference.host_view());
436
+
437
+ EXPECT_TRUE(passed);
438
+ }
439
+
440
+ passed = cutlass::reference::host::TensorEquals(
441
+ tensor_Z_computed.host_view(),
442
+ tensor_Z_reference.host_view());
443
+
444
+ EXPECT_TRUE(passed);
445
+
446
+ if (!passed) {
447
+ std::stringstream fname;
448
+
449
+ fname << "error_Conv2d_ImplicitGemm_device_"
450
+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
451
+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
452
+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
453
+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
454
+ << "nhwc_"
455
+ << problem_size.N << "x"
456
+ << problem_size.H << "x"
457
+ << problem_size.W << "x"
458
+ << problem_size.C
459
+ << "_krsc_"
460
+ << problem_size.K << "x"
461
+ << problem_size.R << "x"
462
+ << problem_size.S << "x"
463
+ << problem_size.C
464
+ << "_padding_"
465
+ << problem_size.pad_h << "x"
466
+ << problem_size.pad_w
467
+ << "_stride_"
468
+ << problem_size.stride_h << "x"
469
+ << problem_size.stride_w
470
+ << "_dilation_"
471
+ << problem_size.dilation_h << "x"
472
+ << problem_size.dilation_w << "_"
473
+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
474
+ << Conv2d::ThreadblockShape::kM << "x"
475
+ << Conv2d::ThreadblockShape::kN << "x"
476
+ << Conv2d::ThreadblockShape::kK << "_"
477
+ << Conv2d::WarpShape::kM << "x"
478
+ << Conv2d::WarpShape::kN << "x"
479
+ << Conv2d::WarpShape::kK << ".txt";
480
+
481
+ std::cout << fname.str() << std::endl;
482
+
483
+ std::ofstream results(fname.str());
484
+
485
+ results << problem_size << std::endl;
486
+
487
+ results
488
+ << "\nA:\n" << tensor_A.host_view() << "\n"
489
+ << "\nB:\n" << tensor_B.host_view() << "\n"
490
+ << "\nC:\n" << tensor_C.host_view() << "\n"
491
+ << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n"
492
+ << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n"
493
+ << "\nT reference:\n" << tensor_T_reference.host_view() << "\n"
494
+ << "\nT computed:\n" << tensor_T_computed.host_view() << "\n"
495
+ << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n"
496
+ << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n";
497
+ }
498
+
499
+ return passed;
500
+ }
501
+ };
502
+
503
+
504
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
505
+
506
+ template <typename ImplicitGemm,
507
+ typename ReferenceOp = Conv2dWithBroadcastReferenceOp<ImplicitGemm>,
508
+ bool AddBroadcastFirst = false>
509
+ bool TestSpecificConv2dWithBroadcast(
510
+ const Conv2dProblemVector & problem_sizes) {
511
+
512
+ bool passed = true;
513
+
514
+ //
515
+ // Testbed object
516
+ //
517
+
518
+ TestbedConv2dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
519
+
520
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
521
+ for(auto conv_problem : problem_sizes) {
522
+
523
+ //
524
+ // Test
525
+ //
526
+
527
+ // test mode = xcross
528
+ passed = testbed.run(
529
+ conv_problem,
530
+ cutlass::conv::SplitKMode::kSerial);
531
+
532
+ if (!passed) {
533
+ return false;
534
+ }
535
+
536
+ // test mode = convolution
537
+ passed = testbed.run(
538
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
539
+ cutlass::conv::SplitKMode::kSerial);
540
+
541
+ if (!passed) {
542
+ return false;
543
+ }
544
+ }
545
+
546
+ return true;
547
+ }
548
+
549
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////
550
+ // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
551
+ // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
552
+ // Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
553
+ // (conv_blacklist_sizes)
554
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
555
+ template <typename ImplicitGemm,
556
+ typename ReferenceOp = Conv2dWithBroadcastReferenceOp<ImplicitGemm>,
557
+ bool AddBroadcastFirst = false,
558
+ bool TestSplitK = true
559
+ >
560
+ bool TestAllConv2dWithBroadcast(
561
+ const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(),
562
+ const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) {
563
+
564
+ bool passed = true;
565
+
566
+ //
567
+ // Testbed object
568
+ //
569
+
570
+ TestbedConv2dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
571
+
572
+ //
573
+ // Get conv problem sizes to run conv operator
574
+ //
575
+ TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
576
+
577
+ // Vector of conv2d problem sizes to avoid duplicate runs
578
+ Conv2dProblemVector conv_tested_sizes;
579
+
580
+ Conv2dProblemVector const *problem_vectors[] = {
581
+ &conv_test_sizes, // run user specified sizes
582
+ &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes
583
+ &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
584
+ #if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
585
+ &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
586
+ #endif
587
+ };
588
+
589
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
590
+ for (Conv2dProblemVector const * problem_vector : problem_vectors) {
591
+
592
+ // Run conv testbed on default convolution sizes
593
+ for(auto conv_problem : *problem_vector) {
594
+
595
+ // Skip blacklist and avoid duplicate problem sizes
596
+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
597
+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
598
+ continue;
599
+ }
600
+
601
+ //
602
+ // Procedurally disable certain cases
603
+ //
604
+
605
+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
606
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
607
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
608
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
609
+ cutlass::conv::StrideSupport::kUnity)) {
610
+ if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
611
+ continue;
612
+ }
613
+ }
614
+
615
+ #if 0 // relax restrictions on analytic strided dgrad
616
+ // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
617
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
618
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
619
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
620
+ cutlass::conv::StrideSupport::kStrided)) {
621
+ if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
622
+ continue;
623
+ }
624
+ }
625
+ #endif
626
+
627
+ //
628
+ // Test
629
+ //
630
+ // push back tested problem size to avoid re-running duplicates
631
+ conv_tested_sizes.push_back(conv_problem);
632
+
633
+ // test mode = xcross
634
+ passed = testbed.run(
635
+ conv_problem,
636
+ cutlass::conv::SplitKMode::kSerial);
637
+
638
+ if (!passed) {
639
+ return false;
640
+ }
641
+
642
+ // test mode = convolution
643
+ passed = testbed.run(
644
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
645
+ cutlass::conv::SplitKMode::kSerial);
646
+
647
+ if (!passed) {
648
+ return false;
649
+ }
650
+ }
651
+ }
652
+
653
+ // CUTLASS DGRAD's *strided* specialization does not support split-k mode
654
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
655
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
656
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
657
+ cutlass::conv::StrideSupport::kStrided)) {
658
+
659
+ passed = testbed.run(
660
+ cutlass::conv::Conv2dProblemSize(
661
+ {1, 56, 56, 8}, // input size (NHWC)
662
+ {8, 1, 1, 8}, // filter size (KRSC)
663
+ {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _)
664
+ {2, 2}, // stride (stride_h, stride_w)
665
+ {1, 1}), // dilation (dilation_h, dilation_w)
666
+ cutlass::conv::SplitKMode::kSerial,
667
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
668
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));
669
+
670
+ if (!passed) {
671
+ return false;
672
+ }
673
+
674
+ return passed;
675
+ }
676
+
677
+ if (!TestSplitK)
678
+ return passed;
679
+
680
+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
681
+ // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
682
+ // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
683
+ // alpha and beta for local testing, but only runs one value for alpha and beta.
684
+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
685
+ {1, 17, 11, 288}, // input size (NHWC)
686
+ {160, 3, 3, 288}, // filter size (KRSC)
687
+ {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _)
688
+ {1, 1}, // stride (stride_h, stride_w)
689
+ {1, 1} // dilation (dilation_h, dilation_w)
690
+ );
691
+
692
+ cutlass::conv::SplitKMode split_k_modes [] = {
693
+ cutlass::conv::SplitKMode::kSerial
694
+ };
695
+
696
+ int split_k_slices[] = {
697
+ 1, 2, 3, 4, 201
698
+ };
699
+
700
+ double problem_alpha[] = {
701
+ 2.0
702
+ };
703
+
704
+ double problem_beta[] = {
705
+ 2.0
706
+ };
707
+
708
+ for (auto split_k_mode : split_k_modes) {
709
+ for (auto split_k_slice : split_k_slices) {
710
+ for (auto alpha : problem_alpha) {
711
+ for (auto beta : problem_beta) {
712
+
713
+ passed = testbed.run(
714
+ conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
715
+ split_k_mode,
716
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
717
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
718
+
719
+ if (!passed) {
720
+ return false;
721
+ }
722
+ }
723
+ }
724
+ }
725
+ }
726
+
727
+ return passed;
728
+ }
729
+
730
+ /////////////////////////////////////////////////////////////////////////////////////////////////
731
+
732
+ } // namespace device
733
+ } // namespace conv
734
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed
33
+ */
34
+ #pragma once
35
+
36
+ #include <fstream>
37
+
38
+ #include "../../common/cutlass_unit_test.h"
39
+ #include "cutlass/cutlass.h"
40
+
41
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
42
+ #include "cutlass/reduction/device/tensor_reduce.h"
43
+ #include "cutlass/reduction/device/reduce_split_k.h"
44
+ #include "cutlass/reduction/thread/reduction_operators.h"
45
+
46
+ #include "conv2d_problems.h"
47
+
48
+ #include "cutlass/util/host_tensor.h"
49
+ #include "cutlass/util/reference/host/tensor_fill.h"
50
+ #include "cutlass/util/reference/device/tensor_compare.h"
51
+ #include "cutlass/util/reference/host/tensor_compare.h"
52
+
53
+ #include "cutlass/util/reference/host/convolution.h"
54
+ #include "cutlass/util/reference/device/convolution.h"
55
+
56
+ #include "cutlass/core_io.h"
57
+ #include "cutlass/util/tensor_view_io.h"
58
+
59
+ #include "../cache_testbed_output.h"
60
+
61
+ namespace test {
62
+ namespace conv {
63
+ namespace device {
64
+
65
+ template <typename Conv2d>
66
+ class TestbedConv2dWithReduction {
67
+ public:
68
+
69
+ using ElementA = typename Conv2d::ElementA;
70
+ using LayoutA = typename Conv2d::LayoutA;
71
+ using ElementB = typename Conv2d::ElementB;
72
+ using LayoutB = typename Conv2d::LayoutB;
73
+ using ElementC = typename Conv2d::ElementC;
74
+ using LayoutC = typename Conv2d::LayoutC;
75
+ using ElementAccumulator = typename Conv2d::ElementAccumulator;
76
+ using ElementCompute = typename Conv2d::ElementCompute;
77
+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
78
+ using ElementT = typename EpilogueOutputOp::ElementTensor;
79
+
80
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
81
+
82
+ public:
83
+
84
+ /// Initialization
85
+ cutlass::Distribution::Kind init_A;
86
+ cutlass::Distribution::Kind init_B;
87
+ cutlass::Distribution::Kind init_C;
88
+ uint64_t seed;
89
+
90
+ cutlass::HostTensor<ElementA, LayoutA> tensor_A;
91
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B;
92
+ cutlass::HostTensor<ElementC, LayoutC> tensor_C;
93
+
94
+ cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Reduction;
95
+ cutlass::HostTensor<ElementT, cutlass::layout::RowMajor> tensor_Tensor;
96
+ cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Final_Reduction;
97
+
98
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
99
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
100
+
101
+ public:
102
+
103
+ TestbedConv2dWithReduction(
104
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
105
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
106
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
107
+ uint64_t seed_ = 2080
108
+ ):
109
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
110
+
111
+ }
112
+
113
+ /// Helper to initialize a tensor view
114
+ template <typename Element, typename Layout>
115
+ void initialize_tensor(
116
+ cutlass::TensorView<Element, Layout> view,
117
+ cutlass::Distribution::Kind dist_kind,
118
+ uint64_t seed) {
119
+
120
+ if (dist_kind == cutlass::Distribution::Uniform) {
121
+
122
+ int scope = 2;
123
+
124
+ cutlass::reference::host::TensorFillRandomUniform(
125
+ view, seed, scope, -scope, 0);
126
+ }
127
+ else if (dist_kind == cutlass::Distribution::Identity) {
128
+
129
+ cutlass::reference::host::TensorFillIdentity(view);
130
+ }
131
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
132
+
133
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
134
+ }
135
+ else if (dist_kind == cutlass::Distribution::Sequential) {
136
+
137
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
138
+ }
139
+ else {
140
+ }
141
+ }
142
+
143
+ void initialize(
144
+ cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
145
+
146
+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
147
+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
148
+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
149
+
150
+ tensor_Reduction.resize({
151
+ 1,
152
+ 1,
153
+ (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM,
154
+ (problem_size.K)
155
+ });
156
+
157
+ tensor_Final_Reduction.resize({
158
+ 1,
159
+ 1,
160
+ 1,
161
+ (problem_size.K)
162
+ });
163
+
164
+ tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K});
165
+
166
+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
167
+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
168
+
169
+ initialize_tensor(tensor_A.host_view(), init_A, seed);
170
+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
171
+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
172
+
173
+ tensor_A.sync_device();
174
+ tensor_B.sync_device();
175
+ tensor_C.sync_device();
176
+ tensor_D_computed.sync_device();
177
+ tensor_D_reference.sync_device();
178
+ }
179
+
180
+ bool sufficient() const {
181
+ //
182
+ // Determine SMEM requirements and waive if not satisfied
183
+ //
184
+
185
+ size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage);
186
+
187
+ cudaDeviceProp properties;
188
+ int device_idx;
189
+ cudaError_t result = cudaGetDevice(&device_idx);
190
+
191
+ if (result != cudaSuccess) {
192
+ throw std::runtime_error("cudaGetDevice() API call failed.");
193
+ }
194
+
195
+ result = cudaGetDeviceProperties(&properties, device_idx);
196
+
197
+ if (result != cudaSuccess) {
198
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
199
+ }
200
+
201
+ if (properties.sharedMemPerBlockOptin < smem_size) {
202
+ return false;
203
+ }
204
+
205
+ return true;
206
+ }
207
+
208
/// Executes one test case: runs the device-side implicit GEMM convolution
/// (whose epilogue additionally writes a partial-reduction tensor and an
/// auxiliary tensor), checks tensor D against a reference convolution, and
/// checks the final per-output-channel (K) reduction against a host loop.
///
/// Returns true when results match — or when the test is waived because the
/// device is insufficient or the configuration is unsupported — and false on
/// any mismatch or launch failure.
bool run(
  cutlass::conv::Conv2dProblemSize const &problem_size,
  cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
  ElementCompute alpha = ElementCompute(1),
  ElementCompute beta = ElementCompute(0)) {

  // Waive test if insufficient CUDA device
  if (!sufficient()) {
    if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
      std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
    }
    return true;
  }

#if 0 //display conv2d problem size for debugging
  std::cout << problem_size << std::endl
    << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
    << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
    << std::endl;
#endif

  // Allocate and fill host/device tensors for this problem size.
  initialize(problem_size);

  // configure the operator
  Conv2d conv2d_op;

  typename Conv2d::Arguments conv2d_args(
    problem_size,
    tensor_A.device_ref(),
    tensor_B.device_ref(),
    tensor_C.device_ref(),
    tensor_D_computed.device_ref(),
    {alpha, beta},
    split_k_mode,
    // Epilogue-with-reduction outputs: partial-reduction buffer, auxiliary
    // tensor, and their leading strides.
    tensor_Reduction.device_data(),
    tensor_Tensor.device_data(),
    static_cast<int>(tensor_Reduction.stride()[0]),
    static_cast<int>(tensor_Tensor.stride()[0])
  );

  // find workspace requirement for parallel split-k reduction
  size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);

  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get());

  // A failed initialize() is treated as "configuration unsupported" and the
  // test is waived rather than failed.
  if (status != cutlass::Status::kSuccess) {
    cudaError_t error = cudaGetLastError();
    std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
    return true;
  }

  // conv2d operation with parallel split-k-mode
  if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {

    // conv2d output is written to workspace in global memory
    conv2d_args.ref_D.reset(reinterpret_cast<ElementC*>(workspace.get()));
    // accumulate mma for each cta in k-dimension (1.0 * A * B)
    conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)};
    // update conv2d operator arguments
    status = conv2d_op.update(conv2d_args, workspace.get());
  }

  EXPECT_TRUE(status == cutlass::Status::kSuccess);
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  // run conv2d operator
  status = conv2d_op();

  EXPECT_TRUE(status == cutlass::Status::kSuccess);
  if (status != cutlass::Status::kSuccess) {
    return false;
  }

  bool passed = false;

  // Ensure the kernel has finished before launching the follow-up reduction.
  cudaError_t result = cudaDeviceSynchronize();
  EXPECT_EQ(result, cudaSuccess) << " device reference error: "
    << cudaGetErrorString(result);

  // Final reduction over the partial reduction tensor
  using Functor = cutlass::plus<ElementAccumulator>;
  using TensorReduction = cutlass::reduction::device::TensorReduction<
    ElementAccumulator,
    ElementAccumulator,
    LayoutC,
    Functor,
    8,
    ElementAccumulator
  >;

  // Reduce along dimension index 2 of the partial-reduction tensor into
  // tensor_Final_Reduction.
  TensorReduction reduction(tensor_Reduction.extent(), 2);

  cutlass::DeviceAllocation<uint8_t> reduction_device_workspace(reduction.workspace_size());

  status = reduction.reduce(
    tensor_Final_Reduction.device_ref(),
    tensor_Reduction.device_ref(),
    reduction_device_workspace.get(),
    ElementAccumulator());

  EXPECT_EQ(status, cutlass::Status::kSuccess);
  EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  //
  // Reference check
  //

  tensor_D_computed.sync_host();

#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED

  // Device-side reference convolution writes into tensor_D_reference.
  cutlass::reference::device::Conv2d<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementCompute,
    ElementAccumulator
  >(
    kConvolutionalOperator,
    problem_size,
    tensor_A.device_ref(),
    tensor_B.device_ref(),
    tensor_C.device_ref(),
    tensor_D_reference.device_ref(),
    alpha,
    beta);

  // sync host (copy device data to host) for dumping error output in case of mismatches
  tensor_D_reference.sync_host();

#else

  // Host-side reference convolution (default path).
  cutlass::reference::host::Conv2d<
    ElementA,
    LayoutA,
    ElementB,
    LayoutB,
    ElementC,
    LayoutC,
    ElementCompute,
    ElementAccumulator
  >(
    kConvolutionalOperator,
    problem_size,
    tensor_A.host_ref(),
    tensor_B.host_ref(),
    tensor_C.host_ref(),
    tensor_D_reference.host_ref(),
    alpha,
    beta);

#endif

  passed = cutlass::reference::host::TensorEquals(
    tensor_D_computed.host_view(),
    tensor_D_reference.host_view());

  EXPECT_TRUE(passed);

  //
  // Reference check on reduction results
  //

  tensor_Reduction.sync_host();
  tensor_Final_Reduction.sync_host();

  // compute backwards for reduction results:
  // for each output channel k, sum the reference D over all (n, p, q)
  // positions; this is what the epilogue reduction + final reduction
  // together are expected to produce.
  cutlass::HostTensor<ElementAccumulator, LayoutC> reference_Reduction;
  reference_Reduction.resize({
    1,
    1,
    1,
    (problem_size.K)
  });

  for (int k = 0; k < problem_size.K; ++k) {
    ElementAccumulator reduced_value = ElementAccumulator();
    for (int n = 0; n < problem_size.N; ++n) {
      for (int p = 0; p < problem_size.P; ++p) {
        for (int q = 0; q < problem_size.Q; ++q) {
          reduced_value += tensor_D_reference.at({n, p, q, k});
        }
      }
    }
    reference_Reduction.at({0, 0, 0, k}) = reduced_value;
  }

  passed = cutlass::reference::host::TensorEquals(
    tensor_Final_Reduction.host_view(),
    reference_Reduction.host_view()
  );

  EXPECT_TRUE(passed);

  // On any mismatch, dump all operands and results to a file whose name
  // encodes the full problem configuration for offline debugging.
  if (!passed) {
    std::stringstream fname;

    fname << "error_Conv2d_ImplicitGemm_device_"
      << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
      << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
        (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
      << "nhwc_"
      << problem_size.N << "x"
      << problem_size.H << "x"
      << problem_size.W << "x"
      << problem_size.C
      << "_krsc_"
      << problem_size.K << "x"
      << problem_size.R << "x"
      << problem_size.S << "x"
      << problem_size.C
      << "_padding_"
      << problem_size.pad_h << "x"
      << problem_size.pad_w
      << "_stride_"
      << problem_size.stride_h << "x"
      << problem_size.stride_w
      << "_dilation_"
      << problem_size.dilation_h << "x"
      << problem_size.dilation_w << "_"
      << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
      << Conv2d::ThreadblockShape::kM << "x"
      << Conv2d::ThreadblockShape::kN << "x"
      << Conv2d::ThreadblockShape::kK << "_"
      << Conv2d::WarpShape::kM << "x"
      << Conv2d::WarpShape::kN << "x"
      << Conv2d::WarpShape::kK << ".txt";

    std::cout << fname.str() << std::endl;

    std::ofstream results(fname.str());

    results << problem_size << std::endl;

    results
      << "\nA:\n" << tensor_A.host_view() << "\n"
      << "\nB:\n" << tensor_B.host_view() << "\n"
      << "\nC:\n" << tensor_C.host_view() << "\n"
      << "\nD reference:\n" << tensor_D_reference.host_view() << "\n"
      << "\nD computed:\n" << tensor_D_computed.host_view() << "\n"
      << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n"
      << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n";
  }

  return passed;
}
462
+ };
463
+
464
/////////////////////////////////////////////////////////////////////////////////////////////////////////
// TestAllConv2dWithReduction: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares
// it with a reference implementation.
// TestAllConv2dWithReduction runs the conv operator on default conv problem sizes from
// test::conv::device::TestbedConv2dProblemSizes. Additionally, each conv2d test can provide conv problem
// sizes (conv_test_sizes) and a blacklist of sizes (conv_blacklist_sizes).
// Returns true if every executed problem size passes; stops at the first failure.
/////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename ImplicitGemm>
bool TestAllConv2dWithReduction(
  const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(),
  const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) {

  bool passed = true;

  //
  // Testbed object
  //

  TestbedConv2dWithReduction<ImplicitGemm> testbed;

  //
  // Get conv problem sizes to run conv operator
  //
  // The constructor argument is the number of ElementA values per 128 bits —
  // presumably the minimum channel size for generated problems; confirm
  // against TestbedConv2dProblemSizes.
  TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);

  // Vector of conv2d problem sizes to avoid duplicate runs
  Conv2dProblemVector conv_tested_sizes;

  Conv2dProblemVector const *problem_vectors[] = {
    &conv_test_sizes,                     // run user specified sizes
    &conv_problems.conv2d_default_sizes,  // run default and cudnn bug sizes
    &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes
#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
    &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled
#endif
  };

  // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (Conv2dProblemVector const * problem_vector : problem_vectors) {

    // Run conv testbed on default convolution sizes
    for(auto conv_problem : *problem_vector) {

      // Skip blacklist and avoid duplicate problem sizes
      if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
          std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
        continue;
      }

      //
      // Procedurally disable certain cases
      //

      // CUTLASS DGRAD's *unity* stride specialization only supports stride {1, 1}
      if ((ImplicitGemm::kConvolutionalOperator ==
            cutlass::conv::Operator::kDgrad) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
            cutlass::conv::StrideSupport::kUnity)) {
        if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }

#if 0 // relax restrictions on analytic strided dgrad
      // CUTLASS DGRAD's *strided* specialization only supports stride >= {2, 2}
      if ((ImplicitGemm::kConvolutionalOperator ==
            cutlass::conv::Operator::kDgrad) &&
          (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
            cutlass::conv::StrideSupport::kStrided)) {
        if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
          continue;
        }
      }
#endif

      //
      // Test
      //
      // push back tested problem size to avoid re-running duplicates
      conv_tested_sizes.push_back(conv_problem);

      // test mode = xcross (cross-correlation)
      passed = testbed.run(
        conv_problem,
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }

      // test mode = convolution
      passed = testbed.run(
        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }
    }
  }

  // CUTLASS DGRAD's *strided* specialization does not support split-k mode:
  // run one fixed strided problem with non-trivial alpha/beta and return
  // early, skipping the split-k sweep below.
  if ((ImplicitGemm::kConvolutionalOperator ==
        cutlass::conv::Operator::kDgrad) &&
      (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
        cutlass::conv::StrideSupport::kStrided)) {

    passed = testbed.run(
      cutlass::conv::Conv2dProblemSize(
        {1, 56, 56, 8},   // input size (NHWC)
        {8, 1, 1, 8},     // filter size (KRSC)
        {0, 0, 0, 0},     // padding (pad_h, _, pad_w, _)
        {2, 2},           // stride (stride_h, stride_w)
        {1, 1}),          // dilation (dilation_h, dilation_w)
      cutlass::conv::SplitKMode::kSerial,
      cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0),
      cutlass::from_real<typename ImplicitGemm::ElementCompute>(2.0));

    if (!passed) {
      return false;
    }

    return passed;
  }

  // Sweep split-k-slice using serial and parallel reduction with non-unity alpha and non-zero beta for
  // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters
  // which are absolutely necessary to catch functional bugs. The below code does provide the option to
  // sweep alpha and beta for local testing, but only runs one value for alpha and beta.
  cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size (
    {1, 17, 11, 288},   // input size (NHWC)
    {160, 3, 3, 288},   // filter size (KRSC)
    {1, 1, 1, 1},       // padding (pad_h, _, pad_w, _)
    {1, 1},             // stride (stride_h, stride_w)
    {1, 1}              // dilation (dilation_h, dilation_w)
  );

  // Parallel SplitK is not tested.
  cutlass::conv::SplitKMode split_k_modes [] = {
    cutlass::conv::SplitKMode::kSerial,
  };

  // 201 deliberately exceeds typical k-iteration counts to exercise the
  // degenerate many-slice case.
  int split_k_slices[] = {
    1, 2, 3, 4, 201
  };

  double problem_alpha[] = {
    2.0
  };

  double problem_beta[] = {
    2.0
  };

  for (auto split_k_mode : split_k_modes) {
    for (auto split_k_slice : split_k_slices) {
      for (auto alpha : problem_alpha) {
        for (auto beta : problem_beta) {

          passed = testbed.run(
            conv2d_split_k_test_size.reset_split_k_slices(split_k_slice),
            split_k_mode,
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));

          if (!passed) {
            return false;
          }
        }
      }
    }
  }

  return passed;
}
638
+
639
+ /////////////////////////////////////////////////////////////////////////////////////////////////
640
+
641
+ } // namespace device
642
+ } // namespace conv
643
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed sizes for Conv2d problem
33
+ */
34
+ #pragma once
35
+
36
+ #include "../../common/cutlass_unit_test.h"
37
+
38
+ #include "cutlass/cutlass.h"
39
+
40
+ #include "cutlass/aligned_buffer.h"
41
+ #include "cutlass/numeric_types.h"
42
+ #include "cutlass/layout/matrix.h"
43
+ #include "cutlass/layout/tensor.h"
44
+ #include "cutlass/layout/pitch_linear.h"
45
+ #include "cutlass/core_io.h"
46
+ #include "cutlass/util/host_tensor.h"
47
+ #include "cutlass/util/tensor_view_io.h"
48
+ #include "cutlass/conv/convolution.h"
49
+ #include "cutlass/conv/conv2d_problem_size.h"
50
+ #include "cutlass/conv/conv3d_problem_size.h"
51
+
52
+ namespace test {
53
+ namespace conv {
54
+ namespace device {
55
+
56
+ using Conv3dProblemVector = std::vector<cutlass::conv::Conv3dProblemSize>;
57
+
58
////////////////////////////////////////////////////////////////////////////
/// Structure TestbedConv3dProblemSizes initializes and holds conv default and
/// important network sizes for Conv3d unit tests. Construction populates the
/// problem-size vectors, then filter_all() removes sizes whose channel count
/// is not a multiple of minimum_channel_size.
////////////////////////////////////////////////////////////////////////////
struct TestbedConv3dProblemSizes {

  //
  // Data members
  //
  // Channel-count granularity: problems whose C is not a multiple of this
  // value are filtered out (typically driven by vectorized-access alignment).
  int minimum_channel_size;
  Conv3dProblemVector conv3d_default_sizes;      // small hand-picked coverage sizes
  Conv3dProblemVector conv3d_vnet_medical_sizes; // layer sizes from the V-Net medical-imaging model

  //
  // Methods
  //
  /// Default ctor: populates both problem vectors and filters illegal sizes.
  TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) {

    initialize_conv3d_default_sizes();
    initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/);

    filter_all();
  }

  /// Eliminates some illegal cases: rebuilds each vector keeping only
  /// problems whose channel count C is a multiple of minimum_channel_size.
  void filter_all() {

    Conv3dProblemVector *problems_vectors[] = {
      &conv3d_default_sizes,
      &conv3d_vnet_medical_sizes
    };

    for (Conv3dProblemVector *problems : problems_vectors) {
      Conv3dProblemVector filtered;

      for (cutlass::conv::Conv3dProblemSize const & problem : *problems) {
        // (problem.C % minimum_channel_size) == 0  =>  keep
        if (!(problem.C % minimum_channel_size)) {
          filtered.push_back(problem);
        }
      }

      *problems = filtered;
    }
  }

  // Add a few standard convolution problem sizes.
  // Note: two entries use the near/far padding tuple overload to exercise
  // asymmetric padding; the rest use symmetric padding.
  void initialize_conv3d_default_sizes() {

    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 1, 3, 3, minimum_channel_size},   // input size  (NDHWC)
      {8, 1, 1, 1, minimum_channel_size},   // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 1, 1, 8, minimum_channel_size},   // input size  (NDHWC)
      {8, 1, 1, 3, minimum_channel_size},   // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    // Asymmetric padding variant of the previous size.
    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 1, 1, 8, minimum_channel_size},   // input size  (NDHWC)
      {8, 1, 1, 3, minimum_channel_size},   // filter size (KTRSC)
      CUTLASS_STL_NAMESPACE::make_tuple(
        cutlass::Coord<3>({1, 1, 1}),       // near padding (pad_d, pad_h, pad_w)
        cutlass::Coord<3>({0, 0, 0})        // far padding  (pad_d, pad_h, pad_w)
      ),
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 8, 8, 8, minimum_channel_size},   // input size  (NDHWC)
      {8, 3, 3, 3, minimum_channel_size},   // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    // Asymmetric padding variant of the previous size.
    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 8, 8, 8, minimum_channel_size},   // input size  (NDHWC)
      {8, 3, 3, 3, minimum_channel_size},   // filter size (KTRSC)
      CUTLASS_STL_NAMESPACE::make_tuple(
        cutlass::Coord<3>({1, 1, 1}),       // near padding (pad_d, pad_h, pad_w)
        cutlass::Coord<3>({0, 0, 0})        // far padding  (pad_d, pad_h, pad_w)
      ),
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 16, 16, 16, minimum_channel_size}, // input size  (NDHWC)
      {8, 3, 3, 3, minimum_channel_size},    // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),          // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),          // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})           // dilation (dilation_d, dilation_h, dilation_w)
    ));

    // Fixed-channel size (C=160) — subject to filter_all() divisibility.
    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 1, 15, 19, 160},                  // input size  (NDHWC)
      {224, 1, 3, 6, 160},                  // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 2, 1, 1, minimum_channel_size},   // input size  (NDHWC)
      {8, 2, 1, 1, minimum_channel_size},   // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 1, 7, 7, minimum_channel_size},   // input size  (NDHWC)
      {16, 1, 3, 3, minimum_channel_size},  // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize(
      {1, 11, 15, 19, 64},                  // input size  (NDHWC)
      {32, 4, 3, 6, 64},                    // filter size (KTRSC)
      cutlass::Coord<3>({2, 1, 3}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));
  }

  // Add vnet layers to unit testing sizes: appends V-Net-style layer shapes
  // to conv3d_problem_vector, parameterized by batch size.
  void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) {

    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 32, 32, 32, 16},         // input size  (NDHWC)
      {32, 2, 2, 2, 16},                    // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({2, 2, 2}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 16, 16, 16, 32},         // input size  (NDHWC)
      {32, 3, 3, 3, 32},                    // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 16, 16, 16, 32},         // input size  (NDHWC)
      {64, 2, 2, 2, 32},                    // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({2, 2, 2}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 8, 8, 8, 64},            // input size  (NDHWC)
      {64, 3, 3, 3, 64},                    // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 8, 8, 8, 64},            // input size  (NDHWC)
      {128, 2, 2, 2, 64},                   // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({2, 2, 2}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 4, 4, 4, 128},           // input size  (NDHWC)
      {128, 3, 3, 3, 128},                  // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 8, 8, 8, 128},           // input size  (NDHWC)
      {128, 3, 3, 3, 128},                  // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 16, 16, 16, 64},         // input size  (NDHWC)
      {64, 3, 3, 3, 64},                    // filter size (KTRSC)
      cutlass::Coord<3>({1, 1, 1}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 32, 32, 32, 16},         // input size  (NDHWC)
      {64, 2, 2, 2, 16},                    // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({2, 2, 2}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));


    conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize(
      {batch_size, 16, 16, 16, 32},         // input size  (NDHWC)
      {128, 2, 2, 2, 32},                   // filter size (KTRSC)
      cutlass::Coord<3>({0, 0, 0}),         // padding (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({2, 2, 2}),         // stride (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1})          // dilation (dilation_d, dilation_h, dilation_w)
    ));

  }

};
290
+
291
+ } // namespace device
292
+ } // namespace conv
293
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h ADDED
@@ -0,0 +1,716 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed
33
+ */
34
+ #pragma once
35
+
36
+ #include <fstream>
37
+
38
+ #include "../../common/cutlass_unit_test.h"
39
+ #include "cutlass/cutlass.h"
40
+
41
+
42
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
43
+ #include "cutlass/reduction/device/reduce_split_k.h"
44
+ #include "cutlass/reduction/thread/reduction_operators.h"
45
+
46
+ #include "cutlass/util/reference/host/tensor_fill.h"
47
+
48
+ #include "cutlass/util/reference/host/convolution.h"
49
+
50
+ #include "cutlass/util/reference/host/tensor_compare.h"
51
+
52
+ #include "cutlass/util/reference/device/convolution.h"
53
+ #include "cutlass/util/reference/device/tensor_compare.h"
54
+
55
+ #include "conv3d_problems.h"
56
+ #include "cutlass/core_io.h"
57
+
58
+ #include "../cache_testbed_output.h"
59
+
60
+ namespace test {
61
+ namespace conv {
62
+ namespace device {
63
+
64
/// Device-side testbed for a CUTLASS 3-D implicit-GEMM convolution operator.
///
/// Allocates and randomly initializes host/device tensors, runs the Conv3d
/// device operator (optionally with serial or parallel split-k reduction),
/// and verifies the computed output against a host or device reference
/// implementation. Reference results may be cached on disk as hashes to
/// speed up repeated test runs.
template <typename Conv3d>
class TestbedConv3d {
public:

  using ElementA = typename Conv3d::ElementA;
  using LayoutA = typename Conv3d::LayoutA;
  using ElementB = typename Conv3d::ElementB;
  using LayoutB = typename Conv3d::LayoutB;
  using ElementC = typename Conv3d::ElementC;
  using LayoutC = typename Conv3d::LayoutC;
  using ElementAccumulator = typename Conv3d::ElementAccumulator;
  using ElementCompute = typename Conv3d::ElementCompute;
  using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp;

  static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator;

  /// Reduction kernel used for the parallel split-k path: sums partial
  /// accumulators produced by each k-slice of the implicit GEMM.
  using ReductionOp = cutlass::reduction::thread::ReduceAdd<
    ElementAccumulator,
    typename EpilogueOutputOp::ElementAccumulator,
    EpilogueOutputOp::kCount
  >;

  using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
    cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
    EpilogueOutputOp,
    ReductionOp
  >;

  using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
  using ReductionStrideIndex = typename ReductionDevice::StrideIndex;

public:

  /// Initialization distributions for each operand and the RNG seed.
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  uint64_t seed;

  cutlass::HostTensor<ElementA, LayoutA> tensor_A;
  cutlass::HostTensor<ElementB, LayoutB> tensor_B;
  cutlass::HostTensor<ElementC, LayoutC> tensor_C;
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;   // output of the device operator
  cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;  // output of the reference implementation

public:

  /// Constructs the testbed; tensors are sized lazily in initialize().
  TestbedConv3d(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {

  }

  /// Helper to initialize a tensor view according to the requested
  /// distribution kind. For Uniform fills, the value range is narrowed for
  /// low-precision element types to reduce accumulation error.
  template <typename Element, typename Layout>
  void initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {

      int scope;
      int bits = cutlass::sizeof_bits<Element>::value;

      // Smaller element types get a tighter uniform range.
      if (bits <= 8) {
        scope = 2;
      }
      else if (bits == 16) {
        scope = 4;
      }
      else {
        scope = 8;
      }
      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, scope, -scope, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {

      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {

      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {

      cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
    }
    else {
      // Unrecognized distribution kinds leave the tensor untouched.
    }
  }

  /// Resizes all tensors for the given problem size, fills A/B/C per the
  /// configured distributions (with distinct derived seeds), and copies
  /// host data to the device.
  void initialize(
    cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) {

    tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
    tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
    tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
    tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));

    initialize_tensor(tensor_A.host_view(), init_A, seed);
    initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
    initialize_tensor(tensor_C.host_view(), init_C, seed * 39);

    tensor_A.sync_device();
    tensor_B.sync_device();
    tensor_C.sync_device();
    tensor_D_computed.sync_device();
    tensor_D_reference.sync_device();
  }

  /// Returns true if the active CUDA device has enough opt-in shared memory
  /// for the kernel's SharedStorage; used to waive tests on small devices.
  bool sufficient() const {
    //
    // Determine SMEM requirements and waive if not satisfied
    //

    size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage);

    cudaDeviceProp properties;
    int device_idx;
    cudaError_t result = cudaGetDevice(&device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDevice() API call failed.");
    }

    result = cudaGetDeviceProperties(&properties, device_idx);

    if (result != cudaSuccess) {
      throw std::runtime_error("cudaGetDeviceProperties() failed");
    }

    if (properties.sharedMemPerBlockOptin < smem_size) {
      return false;
    }

    return true;
  }


  /// Executes one test: runs the device operator for the given problem size
  /// and split-k configuration, then compares against the reference (or a
  /// cached hash of the reference). Returns true on pass or waive; on
  /// failure, dumps tensors to an error file and returns false.
  bool run(
    cutlass::conv::Conv3dProblemSize const &problem_size,
    cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
    ElementCompute alpha = ElementCompute(1),
    ElementCompute beta = ElementCompute()) {


    // Waive test if insufficient CUDA device
    if (!sufficient()) {
      if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
        std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
      }
      return true;
    }

#if 0 // display conv3d problem size for debugging
    std::cout << problem_size << std::endl
              << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl
              << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
              << std::endl;
#endif

    initialize(problem_size);

    // configure the operator
    Conv3d conv3d_op;

    typename Conv3d::Arguments conv3d_args(
      problem_size,
      tensor_A.device_ref(),
      tensor_B.device_ref(),
      tensor_C.device_ref(),
      tensor_D_computed.device_ref(),
      {alpha, beta},
      split_k_mode
    );

    cutlass::Status status = conv3d_op.can_implement(conv3d_args);
    if (status != cutlass::Status::kSuccess) {
      std::cerr << "can_implement failed for the given problem_size: \n";
      return false;
    }

    // find workspace requirement for parallel split-k reduction
    size_t workspace_size = Conv3d::get_workspace_size(conv3d_args);

    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    status = conv3d_op.initialize(conv3d_args, workspace.get());

    if (status != cutlass::Status::kSuccess) {
      cudaError_t error = cudaGetLastError();
      std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
      return true;
    }

    // conv3d operation with parallel split-k-mode
    if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {

      // conv3d output is written to workspace in global memory
      conv3d_args.ref_D.reset(reinterpret_cast<ElementAccumulator*>(workspace.get()));
      // accumulate mma for each cta in k-dimension (1.0 * A * B)
      conv3d_args.output_op = {1.0, 0.0};
      // update conv3d operator arguments
      status = conv3d_op.update(conv3d_args, workspace.get());
    }

    EXPECT_TRUE(status == cutlass::Status::kSuccess);
    if (status != cutlass::Status::kSuccess) {
      return false;
    }

    // run conv3d operator
    status = conv3d_op();

    EXPECT_TRUE(status == cutlass::Status::kSuccess);
    if (status != cutlass::Status::kSuccess) {
      return false;
    }

    if (split_k_mode == cutlass::conv::SplitKMode::kParallel) {

      // configure parallel reduction operator
      ReductionDevice reduction_op;

      typename ReductionDevice::Arguments reduction_args(
        cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
        problem_size.split_k_slices,
        cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
        {
          reinterpret_cast<ElementAccumulator*> (workspace.get()),
          ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
        },
        {
          tensor_D_computed.device_data(),
          ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
        },
        {
          tensor_C.device_data(),
          ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx])
        },
        // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C
        {alpha, beta}
      );

      status = reduction_op.initialize(reduction_args, nullptr);

      EXPECT_TRUE(status == cutlass::Status::kSuccess);
      if (status != cutlass::Status::kSuccess) {
        return false;
      }

      // run parallel reduction kernel
      status = reduction_op();

      EXPECT_TRUE(status == cutlass::Status::kSuccess);
      if (status != cutlass::Status::kSuccess) {
        return false;
      }
    }
    bool passed = false;

    cudaError_t result = cudaDeviceSynchronize();
    EXPECT_EQ(result, cudaSuccess) << " device reference error: "
                                   << cudaGetErrorString(result);

    tensor_D_computed.sync_host();

    //
    // Reference check - support caching results
    //

    CachedTestKey cached_test_key = CreateCachedConv3dTestKey<
      ElementA, LayoutA,
      ElementB, LayoutB,
      ElementC, LayoutC,
      ElementAccumulator,
      ElementCompute
    >(
      kConvolutionalOperator,
      problem_size,
      alpha,
      beta,
      tensor_A.host_view(),
      tensor_B.host_view(),
      tensor_C.host_view()
    );

    //
    // Look for the cached key
    //

    bool cached_result_loaded = false;
    CachedTestResult cached_test_result;

    std::string conv3d_result_cache_name =
      std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";

    if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {

      CachedTestResultListing cached_results(conv3d_result_cache_name);

      auto cached = cached_results.find(cached_test_key);

      cached_result_loaded = cached.first;
      if (cached_result_loaded) {
        cached_test_result = cached.second;
      }
    }

    // Compute the reference result only when no cached hash is available.
    if (!cached_result_loaded) {

#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED

      cutlass::reference::device::Conv3d<
        ElementA,
        LayoutA,
        ElementB,
        LayoutB,
        ElementC,
        LayoutC,
        ElementAccumulator,
        ElementCompute
      >(
        kConvolutionalOperator,
        problem_size,
        tensor_A.device_ref(),
        tensor_B.device_ref(),
        tensor_C.device_ref(),
        tensor_D_reference.device_ref(),
        alpha,
        beta
      );

      // sync host (copy device data to host) for dumping error output in case of mismatches
      tensor_D_reference.sync_host();

#else
      cutlass::reference::host::Conv3d<
        ElementA,
        LayoutA,
        ElementB,
        LayoutB,
        ElementC,
        LayoutC,
        ElementAccumulator,
        ElementCompute
      >(
        kConvolutionalOperator,
        problem_size,
        tensor_A.host_ref(),
        tensor_B.host_ref(),
        tensor_C.host_ref(),
        tensor_D_reference.host_ref(),
        alpha,
        beta
      );
#endif

      if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {

        cached_test_result.D = TensorHash(tensor_D_reference.host_view());

        CachedTestResultListing cached_results(conv3d_result_cache_name);

        cached_results.append(cached_test_key, cached_test_result);
        cached_results.write(conv3d_result_cache_name);
      }
    } // if (!cached_result_loaded)

    // Compare either by hash (cached path) or elementwise (direct path).
    uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
    if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
      passed = (tensor_D_hash == cached_test_result.D);

      EXPECT_EQ(tensor_D_hash, cached_test_result.D)
        << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
    }
    else {

      passed = cutlass::reference::host::TensorEquals(
        tensor_D_computed.host_view(),
        tensor_D_reference.host_view());
    }

    EXPECT_TRUE(passed);

    // On mismatch, dump the full problem configuration and tensor contents
    // to a descriptively named error file for offline debugging.
    if (!passed) {
      std::stringstream fname;

      fname << "error_Conv3d_ImplicitGemm_device_"
            << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
            << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
                (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
                 (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
            << "ndhwc_"
            << problem_size.N << "x"
            << problem_size.D << "x"
            << problem_size.H << "x"
            << problem_size.W << "x"
            << problem_size.C
            << "_ktrsc_"
            << problem_size.K << "x"
            << problem_size.T << "x"
            << problem_size.R << "x"
            << problem_size.S << "x"
            << problem_size.C
            << "_padding_"
            << problem_size.pad_d << "x"
            << problem_size.pad_h << "x"
            << problem_size.pad_w
            << "_stride_"
            << problem_size.stride_d << "x"
            << problem_size.stride_h << "x"
            << problem_size.stride_w
            << "_dilation_"
            << problem_size.dilation_d << "x"
            << problem_size.dilation_h << "x"
            << problem_size.dilation_w << "_"
            << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
            << Conv3d::ThreadblockShape::kM << "x"
            << Conv3d::ThreadblockShape::kN << "x"
            << Conv3d::ThreadblockShape::kK << "_"
            << Conv3d::WarpShape::kM << "x"
            << Conv3d::WarpShape::kN << "x"
            << Conv3d::WarpShape::kK << ".txt";

      std::cout << fname.str() << std::endl;

      std::ofstream results(fname.str());

      results << problem_size << std::endl;

      results
        << "\nA:\n" << tensor_A.host_view() << "\n"
        << "\nB:\n" << tensor_B.host_view() << "\n"
        << "\nC:\n" << tensor_C.host_view() << "\n";


      results << "\nD reference (hash: " << cached_test_result.D << ")\n";

      if (!cached_result_loaded) {
        results
          << tensor_D_reference.host_view() << "\n";
      }

      results
        << "\nD computed (hash: " << tensor_D_hash << ")\n"
        << tensor_D_computed.host_view() << "\n";

    }

    return passed;
  }

};
527
+
528
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////
529
+ // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
530
+ // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes
531
+ // Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
532
+ // (conv_blacklist_sizes)
533
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
534
+
535
/// Runs the ImplicitGemm conv3d operator across the default problem-size
/// sweep plus any caller-provided sizes, in both cross-correlation and
/// convolution modes, then sweeps split-k configurations with non-trivial
/// alpha/beta on a single fixed problem size. Returns false on the first
/// failing configuration.
template <typename ImplicitGemm>
bool TestAllConv3d(
  const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(),
  const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) {

  bool passed = true;

  //
  // Testbed object
  //

  //TestbedConv3d<ImplicitGemm> testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential);
  TestbedConv3d<ImplicitGemm> testbed;

  //
  // Get conv problem sizes to run conv operator
  //
  // The minimum channel count is scaled by element width (128 bits' worth).
  TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);

  // Vector of conv3d problem sizes to avoid duplicate runs
  Conv3dProblemVector conv_tested_sizes;

  Conv3dProblemVector const *problem_vectors[] = {
    &conv3d_problems.conv3d_default_sizes,
    &conv3d_problems.conv3d_vnet_medical_sizes,
    &conv_test_sizes
  };

  // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
  for (Conv3dProblemVector const * problem_vector : problem_vectors) {

    // Run conv testbed on default convolution sizes
    for(auto conv_problem : *problem_vector) {

      // Skip blacklist and avoid duplicate problem sizes
      if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
          std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
        continue;
      }

      //
      // Procedurally disable certain cases
      //

      // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1}
      if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
           ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
          ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
            cutlass::conv::StrideSupport::kUnity) ||
           (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport ==
            cutlass::conv::StrideSupport::kUnity))) {
        if (!((conv_problem.stride_d == 1) &&
              (conv_problem.stride_h == 1) &&
              (conv_problem.stride_w == 1))
            ) {
          continue;
        }
      }

      //
      // Test
      //
      // push back tested problem size to avoid re-running duplicates
      conv_tested_sizes.push_back(conv_problem);

      // test mode = xcross
      passed = testbed.run(
        conv_problem,
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }

      // test mode = convolution
      passed = testbed.run(
        conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
        cutlass::conv::SplitKMode::kSerial);

      if (!passed) {
        return false;
      }
    }
  }

  // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for
  // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters
  // which are absolutely necessary to catch functional bugs. The below code does provide option to sweep
  // alpha and beta for local testing, but only runs one value for alpha and beta.
  cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size (
    {1, 8, 8, 8, 32},               // input size  (NDHWC)
    {32, 3, 3, 3, 32},              // filter size (KTRSC)
    cutlass::Coord<3>({0, 0, 0}),   // padding (pad_d, pad_h, pad_w)
    cutlass::Coord<3>({1, 1, 1}),   // stride (stride_d, stride_h, stride_w)
    cutlass::Coord<3>({1, 1, 1})    // dilation (dilation_d, dilation_h, dilation_w)
  );

  cutlass::conv::SplitKMode split_k_modes [] = {
    cutlass::conv::SplitKMode::kSerial,
    cutlass::conv::SplitKMode::kParallel
  };

  // 201 slices exercises an extreme (more slices than k-iterations) case.
  int split_k_slices[] = {
    1, 2, 3, 4, 201
  };

  double problem_alpha[] = {
    2.0
  };

  double problem_beta[] = {
    2.0
  };

  for (auto split_k_mode : split_k_modes) {
    for (auto split_k_slice : split_k_slices) {
      for (auto alpha : problem_alpha) {
        for (auto beta : problem_beta) {

          passed = testbed.run(
            conv3d_split_k_test_size.reset_split_k_slices(split_k_slice),
            split_k_mode,
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
            cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));

          if (!passed) {
            return false;
          }
        }
      }
    }
  }

  return passed;
}
670
+
671
+ template <typename ImplicitGemm>
672
+ bool TestSpecificConv3d(
673
+ const Conv3dProblemVector & problem_sizes) {
674
+
675
+ bool passed = true;
676
+
677
+ //
678
+ // Testbed object
679
+ //
680
+
681
+ TestbedConv3d<ImplicitGemm> testbed;
682
+
683
+ // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
684
+ for(auto conv_problem : problem_sizes) {
685
+
686
+ //
687
+ // Test
688
+ //
689
+
690
+ // test mode = xcross
691
+ passed = testbed.run(
692
+ conv_problem,
693
+ cutlass::conv::SplitKMode::kSerial);
694
+
695
+ if (!passed) {
696
+ return false;
697
+ }
698
+
699
+ // test mode = convolution
700
+ passed = testbed.run(
701
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
702
+ cutlass::conv::SplitKMode::kSerial);
703
+
704
+ if (!passed) {
705
+ return false;
706
+ }
707
+ }
708
+
709
+ return true;
710
+ }
711
+
712
+ /////////////////////////////////////////////////////////////////////////////////////////////////
713
+
714
+ } // namespace device
715
+ } // namespace conv
716
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h ADDED
@@ -0,0 +1,732 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM for fused epilogue broadcast testbed
33
+
34
+ Parallel split-k is not tested because we can just use regular conv kernel
35
+ when we need to use parallel-splitk. Broadcast can happen in the reduction
36
+ kernel.
37
+ */
38
+ #pragma once
39
+
40
+ #include <fstream>
41
+
42
+ #include "../../common/cutlass_unit_test.h"
43
+ #include "cutlass/cutlass.h"
44
+
45
+ #include "cutlass/conv/device/implicit_gemm_convolution.h"
46
+ #include "cutlass/reduction/device/reduce_split_k.h"
47
+ #include "cutlass/reduction/thread/reduction_operators.h"
48
+
49
+ #include "conv3d_problems.h"
50
+
51
+ #include "cutlass/util/host_tensor.h"
52
+ #include "cutlass/util/reference/host/tensor_fill.h"
53
+ #include "cutlass/util/reference/device/tensor_compare.h"
54
+ #include "cutlass/util/reference/host/tensor_compare.h"
55
+
56
+ #include "cutlass/util/reference/host/convolution.h"
57
+ #include "cutlass/util/reference/device/convolution.h"
58
+
59
+ #include "cutlass/core_io.h"
60
+ #include "cutlass/util/tensor_view_io.h"
61
+
62
+ #include "../cache_testbed_output.h"
63
+
64
+ namespace test {
65
+ namespace conv {
66
+ namespace device {
67
+
68
+ /////////////////////////////////////////////////////////////////////////////////////////////////
69
+
70
/// Host reference for the fused broadcast epilogue: applies the epilogue's
/// binary op to the convolution result and a per-channel bias, then the
/// elementwise op, producing T and Z in the epilogue's output types.
template <typename Conv3d>
struct Conv3dWithBroadcastReferenceOp {

  using OutputOp = typename Conv3d::EpilogueOutputOp;

  using ElementCompute = typename OutputOp::ElementCompute;
  using ElementZ = typename OutputOp::ElementZ;
  using ElementT = typename OutputOp::ElementT;

  // Functors mirroring the two stages of the fused epilogue.
  typename OutputOp::BinaryOp binary_op;
  typename OutputOp::ElementwiseOp elementwise_op;

  Conv3dWithBroadcastReferenceOp() { }

  /// Computes T = BinaryOp(conv3d, bias) and Z = ElementwiseOp(T), carrying
  /// out both stages in ElementCompute precision before conversion.
  void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) {
    ElementCompute broadcast_result = binary_op(conv3d, bias);
    ElementCompute activated = elementwise_op(broadcast_result);

    T = ElementT(broadcast_result);
    Z = ElementZ(activated);
  }
};
92
+
93
+ /////////////////////////////////////////////////////////////////////////////////////////////////
94
+
95
+ // Fused testbed
96
+ //
97
+ // Y = CONV(AB, C)
98
+ //
99
+ // T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k])
100
+ //
101
+ // Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k])
102
+ //
103
+
104
+ template <
105
+ typename Conv3d,
106
+ typename ReferenceOp,
107
+ bool AddBroadcastFirst = false
108
+ >
109
+ class TestbedConv3dWithBroadcast {
110
+ public:
111
+
112
+ using ElementA = typename Conv3d::ElementA;
113
+ using LayoutA = typename Conv3d::LayoutA;
114
+ using ElementB = typename Conv3d::ElementB;
115
+ using LayoutB = typename Conv3d::LayoutB;
116
+ using ElementC = typename Conv3d::ElementC;
117
+ using LayoutC = typename Conv3d::LayoutC;
118
+ using ElementAccumulator = typename Conv3d::ElementAccumulator;
119
+ using ElementCompute = typename Conv3d::ElementCompute;
120
+ using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp;
121
+ using ElementZ = typename EpilogueOutputOp::ElementZ;
122
+ using ElementT = typename EpilogueOutputOp::ElementT;
123
+ using ElementVector = typename EpilogueOutputOp::ElementVector;
124
+
125
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator;
126
+ static const bool kAddBroadcastFirst = AddBroadcastFirst;
127
+ static const bool kStoreT = EpilogueOutputOp::kStoreT;
128
+
129
+ public:
130
+
131
+ /// Initialization
132
+ cutlass::Distribution::Kind init_A;
133
+ cutlass::Distribution::Kind init_B;
134
+ cutlass::Distribution::Kind init_C;
135
+ uint64_t seed;
136
+
137
+ cutlass::HostTensor<ElementA, LayoutA> tensor_A;
138
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B;
139
+ cutlass::HostTensor<ElementC, LayoutC> tensor_C;
140
+ cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_C_reference;
141
+ cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_computed;
142
+ cutlass::HostTensor<ElementZ, LayoutC> tensor_Z_reference;
143
+ cutlass::HostTensor<ElementT, LayoutC> tensor_T_computed;
144
+ cutlass::HostTensor<ElementT, LayoutC> tensor_T_reference;
145
+ cutlass::HostTensor<ElementAccumulator, LayoutC> tensor_Y_reference;
146
+ cutlass::HostTensor<ElementVector, LayoutC> tensor_Broadcast; // Input Broadcast
147
+
148
+ public:
149
+
150
+ TestbedConv3dWithBroadcast(
151
+ cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
152
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
153
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
154
+ uint64_t seed_ = 2080
155
+ ):
156
+ init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {
157
+
158
+ }
159
+
160
+ /// Helper to initialize a tensor view
161
+ template <typename Element, typename Layout>
162
+ void initialize_tensor(
163
+ cutlass::TensorView<Element, Layout> view,
164
+ cutlass::Distribution::Kind dist_kind,
165
+ uint64_t seed) {
166
+
167
+ if (dist_kind == cutlass::Distribution::Uniform) {
168
+
169
+ int scope;
170
+ int bits = cutlass::sizeof_bits<Element>::value;
171
+
172
+ if (bits <= 8) {
173
+ scope = 2;
174
+ }
175
+ else if (bits == 16) {
176
+ if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
177
+ scope = 3;
178
+ }
179
+ else {
180
+ scope = 5;
181
+ }
182
+ }
183
+ else {
184
+ scope = 8;
185
+ }
186
+
187
+ cutlass::reference::host::TensorFillRandomUniform(
188
+ view, seed, scope, -scope, 0);
189
+ }
190
+ else if (dist_kind == cutlass::Distribution::Identity) {
191
+
192
+ cutlass::reference::host::TensorFillIdentity(view);
193
+ }
194
+ else if (dist_kind == cutlass::Distribution::Gaussian) {
195
+
196
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
197
+ }
198
+ else if (dist_kind == cutlass::Distribution::Sequential) {
199
+
200
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
201
+ }
202
+ else {
203
+ }
204
+ }
205
+
206
+ void initialize(
207
+ cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) {
208
+
209
+ // to make the layout of tensors a little bit bigger than the problem size
210
+ cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64);
211
+
212
+ cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size);
213
+ cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size);
214
+ cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size);
215
+
216
+ if (non_packed_test) {
217
+ tensor_A_extent += stride_increment;
218
+ tensor_C_extent += stride_increment;
219
+ }
220
+
221
+ tensor_A.resize(tensor_A_extent);
222
+ tensor_B.resize(tensor_B_extent);
223
+ tensor_C.resize(tensor_C_extent);
224
+ tensor_C_reference.resize(tensor_C_extent);
225
+ tensor_Z_computed.resize(tensor_C_extent);
226
+ tensor_Z_reference.resize(tensor_C_extent);
227
+ tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
228
+ tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
229
+ tensor_Y_reference.resize(tensor_C_extent);
230
+ tensor_Broadcast.resize({
231
+ 1,
232
+ 1,
233
+ 1,
234
+ 1,
235
+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(),
236
+ });
237
+
238
+ initialize_tensor(tensor_A.host_view(), init_A, seed);
239
+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
240
+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
241
+ initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39);
242
+ for (int n = 0; n < tensor_C_reference.extent().n(); ++n) {
243
+ for (int o = 0; o < tensor_C_reference.extent().d(); ++o) {
244
+ for (int p = 0; p < tensor_C_reference.extent().h(); ++p) {
245
+ for (int q = 0; q < tensor_C_reference.extent().w(); ++q) {
246
+ for (int k = 0; k < tensor_C_reference.extent().c(); ++k) {
247
+ tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k}));
248
+ }
249
+ }
250
+ }
251
+ }
252
+ }
253
+ tensor_A.sync_device();
254
+ tensor_B.sync_device();
255
+ tensor_C.sync_device();
256
+ tensor_Broadcast.sync_device();
257
+ tensor_C_reference.sync_device();
258
+ tensor_Z_computed.sync_device();
259
+ tensor_Z_reference.sync_device();
260
+ tensor_T_computed.sync_device();
261
+ tensor_T_reference.sync_device();
262
+ tensor_Y_reference.sync_device();
263
+ }
264
+
265
+ bool sufficient() const {
266
+ //
267
+ // Determine SMEM requirements and waive if not satisfied
268
+ //
269
+
270
+ size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage);
271
+
272
+ cudaDeviceProp properties;
273
+ int device_idx;
274
+ cudaError_t result = cudaGetDevice(&device_idx);
275
+
276
+ if (result != cudaSuccess) {
277
+ throw std::runtime_error("cudaGetDevice() API call failed.");
278
+ }
279
+
280
+ result = cudaGetDeviceProperties(&properties, device_idx);
281
+
282
+ if (result != cudaSuccess) {
283
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
284
+ }
285
+
286
+ if (properties.sharedMemPerBlockOptin < smem_size) {
287
+ return false;
288
+ }
289
+
290
+ return true;
291
+ }
292
+
293
+ /// Executes one test
294
+ bool run(
295
+ cutlass::conv::Conv3dProblemSize const &problem_size,
296
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
297
+ bool non_packed_test = false,
298
+ ElementCompute alpha = ElementCompute(1),
299
+ ElementCompute beta = ElementCompute(1)) {
300
+
301
+ // Waive test if insufficient CUDA device
302
+ if (!sufficient()) {
303
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
304
+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
305
+ }
306
+ return true;
307
+ }
308
+
309
+ #if 0 //display conv3d problem size for debugging
310
+ std::cout << problem_size << std::endl
311
+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
312
+ << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl
313
+ << std::endl;
314
+ #endif
315
+
316
+ initialize(problem_size, non_packed_test);
317
+
318
+ // configure the operator
319
+ Conv3d conv3d_op;
320
+ typename Conv3d::Arguments conv3d_args(
321
+ problem_size,
322
+ tensor_A.device_ref(),
323
+ tensor_B.device_ref(),
324
+ tensor_C.device_ref(),
325
+ tensor_Z_computed.device_ref(),
326
+ {alpha, beta},
327
+ split_k_mode,
328
+ tensor_Broadcast.device_data(),
329
+ kStoreT ? tensor_T_computed.device_data() : nullptr,
330
+ 0, // This must be zero
331
+ implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()
332
+ );
333
+
334
+ // initialize the kernel
335
+ size_t workspace_size = Conv3d::get_workspace_size(conv3d_args);
336
+
337
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
338
+
339
+ cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get());
340
+
341
+ if (status != cutlass::Status::kSuccess) {
342
+ cudaError_t error = cudaGetLastError();
343
+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
344
+ return true;
345
+ }
346
+
347
+ // run conv3d operator
348
+ status = conv3d_op();
349
+
350
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
351
+ if (status != cutlass::Status::kSuccess) {
352
+ return false;
353
+ }
354
+
355
+ bool passed = false;
356
+
357
+ cudaError_t result = cudaDeviceSynchronize();
358
+ EXPECT_EQ(result, cudaSuccess) << " device reference error: "
359
+ << cudaGetErrorString(result);
360
+
361
+ tensor_T_computed.sync_host();
362
+ tensor_Z_computed.sync_host();
363
+
364
+ //
365
+ // Reference check
366
+ //
367
+
368
+ // When kAddBroadcastFirst is true, add bias on the host
369
+ ElementCompute beta_ref = kAddBroadcastFirst ? ElementCompute(0) : beta;
370
+
371
+ #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
372
+
373
+ cutlass::reference::device::Conv3d<
374
+ ElementA,
375
+ LayoutA,
376
+ ElementB,
377
+ LayoutB,
378
+ ElementAccumulator,
379
+ LayoutC,
380
+ ElementAccumulator,
381
+ ElementAccumulator
382
+ >(
383
+ kConvolutionalOperator,
384
+ problem_size,
385
+ tensor_A.device_ref(),
386
+ tensor_B.device_ref(),
387
+ tensor_C_reference.device_ref(),
388
+ tensor_Y_reference.device_ref(),
389
+ alpha,
390
+ beta_ref);
391
+
392
+ // sync host (copy device data to host) for dumping error output in case of mismatches
393
+ tensor_Y_reference.sync_host();
394
+
395
+ #else
396
+
397
+ cutlass::reference::host::Conv3d<
398
+ ElementA,
399
+ LayoutA,
400
+ ElementB,
401
+ LayoutB,
402
+ ElementAccumulator,
403
+ LayoutC,
404
+ ElementAccumulator,
405
+ ElementAccumulator
406
+ >(
407
+ kConvolutionalOperator,
408
+ problem_size,
409
+ tensor_A.host_ref(),
410
+ tensor_B.host_ref(),
411
+ tensor_C_reference.host_ref(),
412
+ tensor_Y_reference.host_ref(),
413
+ alpha,
414
+ beta_ref);
415
+
416
+ #endif
417
+ ReferenceOp reference_op;
418
+
419
+ // compute tensor Z and tensor T
420
+ for (int n = 0; n < problem_size.N; ++n) {
421
+ for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) {
422
+ for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) {
423
+ for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) {
424
+ for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) {
425
+
426
+ ElementZ z{};
427
+ ElementT t{};
428
+
429
+ ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k});
430
+ ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k}));
431
+
432
+
433
+ if (kAddBroadcastFirst) {
434
+ reference_op(z, t, accum + bias,
435
+ beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k})));
436
+ } else {
437
+ reference_op(z, t, accum, bias);
438
+ }
439
+
440
+ tensor_Z_reference.at({n, o, p, q, k}) = z;
441
+ tensor_T_reference.at({n, o, p, q, k}) = t;
442
+ }
443
+ }
444
+ }
445
+ }
446
+ }
447
+
448
+ if (kStoreT) {
449
+ passed = cutlass::reference::host::TensorEquals(
450
+ tensor_T_computed.host_view(),
451
+ tensor_T_reference.host_view());
452
+
453
+ EXPECT_TRUE(passed);
454
+ }
455
+
456
+ passed = cutlass::reference::host::TensorEquals(
457
+ tensor_Z_computed.host_view(),
458
+ tensor_Z_reference.host_view());
459
+
460
+ EXPECT_TRUE(passed);
461
+
462
+ if (!passed) {
463
+ std::stringstream fname;
464
+
465
+ fname << "error_Conv3d_ImplicitGemm_device_"
466
+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
467
+ << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
468
+ (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" :
469
+ (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_")))
470
+ << "nnhwc_"
471
+ << problem_size.N << "x"
472
+ << problem_size.D << "x"
473
+ << problem_size.H << "x"
474
+ << problem_size.W << "x"
475
+ << problem_size.C
476
+ << "_krsc_"
477
+ << problem_size.K << "x"
478
+ << problem_size.T << "x"
479
+ << problem_size.R << "x"
480
+ << problem_size.S << "x"
481
+ << problem_size.C
482
+ << "_padding_"
483
+ << problem_size.pad_d << "x"
484
+ << problem_size.pad_h << "x"
485
+ << problem_size.pad_w
486
+ << "_stride_"
487
+ << problem_size.stride_d << "x"
488
+ << problem_size.stride_h << "x"
489
+ << problem_size.stride_w
490
+ << "_dilation_"
491
+ << problem_size.dilation_d << "x"
492
+ << problem_size.dilation_h << "x"
493
+ << problem_size.dilation_w << "_"
494
+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_")
495
+ << (non_packed_test ? "non_packed_tensor_test_" : "packed_tensor_test_")
496
+ << Conv3d::ThreadblockShape::kM << "x"
497
+ << Conv3d::ThreadblockShape::kN << "x"
498
+ << Conv3d::ThreadblockShape::kK << "_"
499
+ << Conv3d::WarpShape::kM << "x"
500
+ << Conv3d::WarpShape::kN << "x"
501
+ << Conv3d::WarpShape::kK << ".txt";
502
+
503
+ std::cout << fname.str() << std::endl;
504
+
505
+ std::ofstream results(fname.str());
506
+
507
+ results << problem_size << std::endl;
508
+
509
+ results
510
+ << "\nA:\n" << tensor_A.host_view() << "\n"
511
+ << "\nB:\n" << tensor_B.host_view() << "\n"
512
+ << "\nC:\n" << tensor_C.host_view() << "\n"
513
+ << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n"
514
+ << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n"
515
+ << "\nT reference:\n" << tensor_T_reference.host_view() << "\n"
516
+ << "\nT computed:\n" << tensor_T_computed.host_view() << "\n"
517
+ << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n"
518
+ << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n";
519
+ }
520
+
521
+ return passed;
522
+ }
523
+ };
524
+
525
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////
526
+ // TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference
527
+ // TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes
528
+ // Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes
529
+ // (conv_blacklist_sizes)
530
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
531
+ template <typename ImplicitGemm,
532
+ typename ReferenceOp = Conv3dWithBroadcastReferenceOp<ImplicitGemm>,
533
+ bool AddBroadcastFirst = false,
534
+ bool TestSplitK = true
535
+ >
536
+ bool TestAllConv3dWithBroadcast(
537
+ const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(),
538
+ const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(),
539
+ bool non_packed_test = false) {
540
+
541
+ bool passed = true;
542
+
543
+ //
544
+ // Testbed object
545
+ //
546
+
547
+ TestbedConv3dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
548
+
549
+ //
550
+ // Get conv problem sizes to run conv operator
551
+ //
552
+ TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits<typename ImplicitGemm::ElementA>::value);
553
+
554
+ // Vector of conv3d problem sizes to avoid duplicate runs
555
+ Conv3dProblemVector conv_tested_sizes;
556
+
557
+ Conv3dProblemVector const *problem_vectors[] = {
558
+ &conv3d_problems.conv3d_default_sizes,
559
+ &conv3d_problems.conv3d_vnet_medical_sizes,
560
+ &conv_test_sizes
561
+ };
562
+
563
+ // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
564
+ for (Conv3dProblemVector const * problem_vector : problem_vectors) {
565
+
566
+ // Run conv testbed on default convolution sizes
567
+ for(auto conv_problem : *problem_vector) {
568
+
569
+ // Skip blacklist and avoid duplicate problem sizes
570
+ if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() ||
571
+ std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) {
572
+ continue;
573
+ }
574
+
575
+ //
576
+ // Procedurally disable certain cases
577
+ //
578
+
579
+ // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1}
580
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
581
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
582
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
583
+ cutlass::conv::StrideSupport::kUnity)) {
584
+ if (!((conv_problem.stride_d == 1) &&
585
+ (conv_problem.stride_h == 1) &&
586
+ (conv_problem.stride_w == 1))
587
+ ) {
588
+ continue;
589
+ }
590
+ }
591
+
592
+ #if 0 // relax restrictions on analytic strided dgrad
593
+ // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2}
594
+ if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ||
595
+ ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) &&
596
+ (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport ==
597
+ cutlass::conv::StrideSupport::kStrided)) {
598
+ if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) {
599
+ continue;
600
+ }
601
+ }
602
+ #endif
603
+
604
+ //
605
+ // Test
606
+ //
607
+ // push back tested problem size to avoid re-running duplicates
608
+ conv_tested_sizes.push_back(conv_problem);
609
+
610
+ // test mode = xcross
611
+ passed = testbed.run(
612
+ conv_problem,
613
+ cutlass::conv::SplitKMode::kSerial, non_packed_test);
614
+
615
+ if (!passed) {
616
+ return false;
617
+ }
618
+
619
+ // test mode = convolution
620
+ passed = testbed.run(
621
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
622
+ cutlass::conv::SplitKMode::kSerial, non_packed_test);
623
+
624
+ if (!passed) {
625
+ return false;
626
+ }
627
+ }
628
+ }
629
+
630
+ if (!TestSplitK)
631
+ return passed;
632
+
633
+ // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for
634
+ // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters
635
+ // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep
636
+ // alpha and beta for local testing, but only runs one value for alpha and beta.
637
+ cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size (
638
+ {1, 8, 8, 8, 32}, // input size (NDHWC)
639
+ {32, 3, 3, 3, 32}, // filter size (KTRSC)
640
+ cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w)
641
+ cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w)
642
+ cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w)
643
+ );
644
+
645
+ cutlass::conv::SplitKMode split_k_modes [] = {
646
+ cutlass::conv::SplitKMode::kSerial
647
+ };
648
+
649
+ int split_k_slices[] = {
650
+ 1, 2, 3, 4, 201
651
+ };
652
+
653
+ double problem_alpha[] = {
654
+ 2.0
655
+ };
656
+
657
+ double problem_beta[] = {
658
+ 2.0
659
+ };
660
+
661
+ for (auto split_k_mode : split_k_modes) {
662
+ for (auto split_k_slice : split_k_slices) {
663
+ for (auto alpha : problem_alpha) {
664
+ for (auto beta : problem_beta) {
665
+
666
+ passed = testbed.run(
667
+ conv3d_split_k_test_size.reset_split_k_slices(split_k_slice),
668
+ split_k_mode,
669
+ false,/*non_packed_test*/
670
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(alpha),
671
+ cutlass::from_real<typename ImplicitGemm::ElementCompute>(beta));
672
+
673
+ if (!passed) {
674
+ return false;
675
+ }
676
+ }
677
+ }
678
+ }
679
+ }
680
+
681
+ return passed;
682
+ }
683
+
684
+ template <typename ImplicitGemm,
685
+ typename ReferenceOp = Conv3dWithBroadcastReferenceOp<ImplicitGemm>,
686
+ bool AddBroadcastFirst = false>
687
+ bool TestSpecificConv3dWithBroadcast(
688
+ const Conv3dProblemVector & problem_sizes,
689
+ bool non_packed_test = false) {
690
+
691
+ bool passed = true;
692
+
693
+ //
694
+ // Testbed object
695
+ //
696
+
697
+ TestbedConv3dWithBroadcast<ImplicitGemm, ReferenceOp, AddBroadcastFirst> testbed;
698
+
699
+ // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
700
+ for(auto conv_problem : problem_sizes) {
701
+
702
+ //
703
+ // Test
704
+ //
705
+
706
+ // test mode = xcross, non_packed_test = false
707
+ passed = testbed.run(
708
+ conv_problem,
709
+ cutlass::conv::SplitKMode::kSerial, non_packed_test);
710
+
711
+ if (!passed) {
712
+ return false;
713
+ }
714
+
715
+ // test mode = convolution, non_packed_test = false
716
+ passed = testbed.run(
717
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
718
+ cutlass::conv::SplitKMode::kSerial, non_packed_test);
719
+
720
+ if (!passed) {
721
+ return false;
722
+ }
723
+ }
724
+
725
+ return true;
726
+ }
727
+
728
+ /////////////////////////////////////////////////////////////////////////////////////////////////
729
+
730
+ } // namespace device
731
+ } // namespace conv
732
+ } // namespace test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Depthwise Direct Conv testbed
33
+ */
34
+ #pragma once
35
+
36
+ #include <fstream>
37
+
38
+ #include "../../common/cutlass_unit_test.h"
39
+ #include "../cache_testbed_output.h"
40
+ #include "conv2d_problems.h"
41
+ #include "cutlass/conv/device/direct_convolution.h"
42
+
43
+ #include "cutlass/core_io.h"
44
+ #include "cutlass/cutlass.h"
45
+ #include "cutlass/util/host_tensor.h"
46
+ #include "cutlass/util/reference/device/convolution.h"
47
+ #include "cutlass/util/reference/device/tensor_compare.h"
48
+ #include "cutlass/util/reference/host/convolution.h"
49
+ #include "cutlass/util/reference/host/tensor_compare.h"
50
+ #include "cutlass/util/reference/host/tensor_fill.h"
51
+ #include "cutlass/util/tensor_view_io.h"
52
+
53
+ namespace test {
54
+ namespace conv {
55
+ namespace device {
56
+
57
+ template <typename Conv2d>
58
+ class TestbedDepthwiseDirectConv2d {
59
+ public:
60
+
61
+ using ElementA = typename Conv2d::ElementA;
62
+ using LayoutA = typename Conv2d::LayoutA;
63
+ using ElementB = typename Conv2d::ElementB;
64
+ using LayoutB = typename Conv2d::LayoutB;
65
+ using ElementC = typename Conv2d::ElementC;
66
+ using LayoutC = typename Conv2d::LayoutC;
67
+ using ElementAccumulator = typename Conv2d::ElementAccumulator;
68
+ using ElementCompute = typename Conv2d::ElementCompute;
69
+ using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp;
70
+
71
+ static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator;
72
+
73
+ public:
74
+ /// Initialization
75
+ cutlass::Distribution::Kind init_A;
76
+ cutlass::Distribution::Kind init_B;
77
+ cutlass::Distribution::Kind init_C;
78
+ uint64_t seed;
79
+
80
+ cutlass::HostTensor<ElementA, LayoutA> tensor_A;
81
+ cutlass::HostTensor<ElementB, LayoutB> tensor_B;
82
+ cutlass::HostTensor<ElementB, LayoutB> tensor_reordered_B;
83
+ cutlass::HostTensor<ElementC, LayoutC> tensor_C;
84
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_computed;
85
+ cutlass::HostTensor<ElementC, LayoutC> tensor_D_reference;
86
+
87
+ int tested_problem_count;
88
+
89
+ public:
90
+ TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
91
+ cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
92
+ cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
93
+ uint64_t seed_ = 2080)
94
+ : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {}
95
+
96
+ /// Helper to initialize a tensor view
97
+ template <typename Element, typename Layout>
98
+ void initialize_tensor(cutlass::TensorView<Element, Layout> view,
99
+ cutlass::Distribution::Kind dist_kind,
100
+ uint64_t seed) {
101
+ if (dist_kind == cutlass::Distribution::Uniform) {
102
+ int scope;
103
+ int bits = cutlass::sizeof_bits<Element>::value;
104
+
105
+ if (bits <= 8) {
106
+ scope = 2;
107
+ } else if (bits == 16) {
108
+ if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) {
109
+ scope = 3;
110
+ } else {
111
+ scope = 5;
112
+ }
113
+ } else {
114
+ scope = 8;
115
+ }
116
+ cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0);
117
+ } else if (dist_kind == cutlass::Distribution::Identity) {
118
+ cutlass::reference::host::TensorFillIdentity(view);
119
+
120
+ } else if (dist_kind == cutlass::Distribution::Gaussian) {
121
+ cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
122
+ } else if (dist_kind == cutlass::Distribution::Sequential) {
123
+ cutlass::reference::host::BlockFillSequential(view.data(), view.capacity());
124
+ } else {
125
+ }
126
+ }
127
+
128
+ void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) {
129
+ tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size));
130
+ tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
131
+ tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size));
132
+ tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
133
+ tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
134
+ tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size));
135
+
136
+ initialize_tensor(tensor_A.host_view(), init_A, seed);
137
+ initialize_tensor(tensor_B.host_view(), init_B, seed * 17);
138
+ initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17);
139
+ initialize_tensor(tensor_C.host_view(), init_C, seed * 39);
140
+
141
+ tensor_A.sync_device();
142
+ tensor_B.sync_device();
143
+ tensor_reordered_B.sync_device();
144
+ tensor_C.sync_device();
145
+ tensor_D_computed.sync_device();
146
+ tensor_D_reference.sync_device();
147
+ }
148
+
149
+ bool sufficient(int smem_size) const {
150
+ //
151
+ // Determine SMEM requirements and waive if not satisfied
152
+ //
153
+
154
+ cudaDeviceProp properties;
155
+ int device_idx;
156
+ cudaError_t result = cudaGetDevice(&device_idx);
157
+
158
+ if (result != cudaSuccess) {
159
+ throw std::runtime_error("cudaGetDevice() API call failed.");
160
+ }
161
+
162
+ result = cudaGetDeviceProperties(&properties, device_idx);
163
+
164
+ if (result != cudaSuccess) {
165
+ throw std::runtime_error("cudaGetDeviceProperties() failed");
166
+ }
167
+
168
+ if (properties.sharedMemPerBlockOptin < static_cast<size_t>(smem_size)) {
169
+ return false;
170
+ }
171
+
172
+ return true;
173
+ }
174
+
175
+ /// Executes one test
176
+ bool run(cutlass::conv::Conv2dProblemSize const &problem_size,
177
+ cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial,
178
+ ElementCompute alpha = ElementCompute(1.5),
179
+ ElementCompute beta = ElementCompute(1)) {
180
+ // increment tested problem count run by the testbed
181
+ tested_problem_count++;
182
+
183
+ #if 0 // display conv2d problem size for debugging
184
+ std::cout << problem_size << std::endl
185
+ << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl
186
+ << "split_k_mode: "
187
+ << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)")
188
+ << std::endl
189
+ << std::endl;
190
+ #endif
191
+
192
+ initialize(problem_size);
193
+
194
+ // configure the operator
195
+ Conv2d conv2d_op;
196
+
197
+ typename Conv2d::Arguments conv2d_args(problem_size,
198
+ tensor_A.device_ref(),
199
+ tensor_B.device_ref(),
200
+ tensor_C.device_ref(),
201
+ tensor_D_computed.device_ref(),
202
+ {alpha, beta},
203
+ tensor_reordered_B.device_ref(),
204
+ split_k_mode);
205
+
206
+ // find workspace requirement for parallel split-k reduction
207
+ size_t workspace_size = Conv2d::get_workspace_size(conv2d_args);
208
+
209
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
210
+
211
+ cutlass::Status status = conv2d_op.can_implement(problem_size);
212
+
213
+ if (status != cutlass::Status::kSuccess) {
214
+ cudaError_t error = cudaGetLastError();
215
+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
216
+ return true;
217
+ }
218
+
219
+ status = conv2d_op.initialize(conv2d_args, workspace.get());
220
+
221
+ if (status != cutlass::Status::kSuccess) {
222
+ cudaError_t error = cudaGetLastError();
223
+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
224
+ return true;
225
+ }
226
+
227
+ if (!sufficient(conv2d_op.get_smem_size())) {
228
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
229
+ std::cerr << "Test waived due to insufficient CUDA device." << std::endl;
230
+ }
231
+ return true;
232
+ }
233
+
234
+ // run conv2d operator
235
+ status = conv2d_op();
236
+
237
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
238
+ if (status != cutlass::Status::kSuccess) {
239
+ std::cerr << "Failed to run." << std::endl;
240
+ return false;
241
+ }
242
+
243
+ bool passed = false;
244
+
245
+ cudaError_t result = cudaDeviceSynchronize();
246
+ EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result);
247
+
248
+ tensor_D_computed.sync_host();
249
+
250
+ //
251
+ // Reference check - support caching results
252
+ //
253
+
254
+ CachedTestKey cached_test_key =
255
+ CreateCachedConv2dTestKey<ElementA,
256
+ LayoutA,
257
+ ElementB,
258
+ LayoutB,
259
+ ElementC,
260
+ LayoutC,
261
+ ElementAccumulator,
262
+ ElementCompute>(kConvolutionalOperator,
263
+ problem_size,
264
+ alpha,
265
+ beta,
266
+ tensor_A.host_view(),
267
+ tensor_B.host_view(),
268
+ tensor_C.host_view());
269
+
270
+ //
271
+ // Look for the cached key
272
+ //
273
+
274
+ bool cached_result_loaded = false;
275
+ CachedTestResult cached_test_result;
276
+
277
+ std::string conv2d_result_cache_name =
278
+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
279
+
280
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
281
+
282
+ CachedTestResultListing cached_results(conv2d_result_cache_name);
283
+
284
+ auto cached = cached_results.find(cached_test_key);
285
+
286
+ cached_result_loaded = cached.first;
287
+ if (cached_result_loaded) {
288
+ cached_test_result = cached.second;
289
+ }
290
+ }
291
+
292
+ if (!cached_result_loaded) {
293
+ #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED
294
+
295
+ cutlass::reference::device::Conv2d<ElementA,
296
+ LayoutA,
297
+ ElementB,
298
+ LayoutB,
299
+ ElementC,
300
+ LayoutC,
301
+ ElementCompute,
302
+ ElementAccumulator>(kConvolutionalOperator,
303
+ problem_size,
304
+ tensor_A.device_ref(),
305
+ tensor_B.device_ref(),
306
+ tensor_C.device_ref(),
307
+ tensor_D_reference.device_ref(),
308
+ alpha,
309
+ beta);
310
+
311
+ // sync host (copy device data to host) for dumping error output in case of mismatches
312
+ tensor_D_reference.sync_host();
313
+
314
+ #else
315
+
316
+ cutlass::reference::host::Conv2d<ElementA,
317
+ LayoutA,
318
+ ElementB,
319
+ LayoutB,
320
+ ElementC,
321
+ LayoutC,
322
+ ElementCompute,
323
+ ElementAccumulator>(kConvolutionalOperator,
324
+ problem_size,
325
+ tensor_A.host_ref(),
326
+ tensor_B.host_ref(),
327
+ tensor_C.host_ref(),
328
+ tensor_D_reference.host_ref(),
329
+ alpha,
330
+ beta);
331
+
332
+ #endif
333
+
334
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
335
+
336
+ cached_test_result.D = TensorHash(tensor_D_reference.host_view());
337
+
338
+ CachedTestResultListing cached_results(conv2d_result_cache_name);
339
+
340
+ cached_results.append(cached_test_key, cached_test_result);
341
+ cached_results.write(conv2d_result_cache_name);
342
+ }
343
+ } // if (!cached_result_loaded)
344
+
345
+ uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view());
346
+
347
+ if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) {
348
+ passed = (tensor_D_hash == cached_test_result.D);
349
+
350
+ EXPECT_EQ(tensor_D_hash, cached_test_result.D)
351
+ << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n";
352
+ }
353
+ else {
354
+
355
+ passed = cutlass::reference::host::TensorEquals(
356
+ tensor_D_computed.host_view(),
357
+ tensor_D_reference.host_view());
358
+ }
359
+
360
+ EXPECT_TRUE(passed);
361
+
362
+ std::stringstream ss_problem_size_text;
363
+ ss_problem_size_text << "nhwc_"
364
+ << problem_size.N << "x"
365
+ << problem_size.H << "x"
366
+ << problem_size.W << "x"
367
+ << problem_size.C
368
+ << "_krsc_"
369
+ << problem_size.K << "x"
370
+ << problem_size.R << "x"
371
+ << problem_size.S << "x"
372
+ << problem_size.C
373
+ << "_padding_"
374
+ << problem_size.pad_h << "x"
375
+ << problem_size.pad_w
376
+ << "_stride_"
377
+ << problem_size.stride_h << "x"
378
+ << problem_size.stride_w
379
+ << "_dilation_"
380
+ << problem_size.dilation_h << "x"
381
+ << problem_size.dilation_w << "_"
382
+ << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_");
383
+
384
+ if (!passed) {
385
+ std::stringstream fname;
386
+
387
+ fname << "error_Conv2d_DirectConv_device_"
388
+ << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_")
389
+ << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" :
390
+ (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_"))
391
+ << ss_problem_size_text.str()
392
+ << Conv2d::ThreadblockShape::kM << "x"
393
+ << Conv2d::ThreadblockShape::kN << "x"
394
+ << Conv2d::ThreadblockShape::kK << "_"
395
+ << Conv2d::WarpShape::kM << "x"
396
+ << Conv2d::WarpShape::kN << "x"
397
+ << Conv2d::WarpShape::kK << ".txt";
398
+
399
+ std::cout << fname.str() << std::endl;
400
+
401
+ std::ofstream results(fname.str());
402
+
403
+ results << problem_size << std::endl;
404
+
405
+ results
406
+ << "\nA:\n" << tensor_A.host_view() << "\n"
407
+ << "\nB:\n" << tensor_B.host_view() << "\n"
408
+ << "\nC:\n" << tensor_C.host_view() << "\n";
409
+
410
+ results << "\nD reference (hash: " << cached_test_result.D << ")\n";
411
+
412
+ if (!cached_result_loaded) {
413
+ results
414
+ << tensor_D_reference.host_view() << "\n";
415
+ }
416
+
417
+ results
418
+ << "\nD computed (hash: " << tensor_D_hash << ")\n"
419
+ << tensor_D_computed.host_view() << "\n";
420
+
421
+ }
422
+
423
+ return passed;
424
+ }
425
+
426
+ };
427
+
428
+ /////////////////////////////////////////////////////////////////////////////////////////////////////////////
429
+
430
+ template <typename DirectConv>
431
+ bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) {
432
+ bool passed = true;
433
+
434
+ //
435
+ // Testbed object
436
+ //
437
+ TestbedDepthwiseDirectConv2d<DirectConv> testbed;
438
+
439
+ // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0)
440
+ for (auto conv_problem : problem_sizes) {
441
+ //
442
+ // Test
443
+ //
444
+
445
+ // test mode = xcross
446
+ passed = testbed.run(
447
+ conv_problem,
448
+ cutlass::conv::SplitKMode::kSerial);
449
+
450
+ if (!passed) {
451
+ return false;
452
+ }
453
+
454
+ // test mode = convolution
455
+ passed = testbed.run(
456
+ conv_problem.reset_mode(cutlass::conv::Mode::kConvolution),
457
+ cutlass::conv::SplitKMode::kSerial);
458
+
459
+ if (!passed) {
460
+ return false;
461
+ }
462
+ }
463
+
464
+ return true;
465
+ }
466
+
467
+ /////////////////////////////////////////////////////////////////////////////////////////////////
468
+
469
+ } // namespace device
470
+ } // namespace conv
471
+ } // namespace test
472
+
473
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp ADDED
@@ -0,0 +1,1385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/conv/convnd_problem_shape.hpp"
37
+ #include <vector>
38
+
39
+ /////////////////////////////////////////////////////////////////////////////////////////////////
40
+
41
+ namespace test::conv::device {
42
+
43
+ /////////////////////////////////////////////////////////////////////////////////////////////////
44
+
45
+ template<int SpatialDim, cutlass::conv::Operator ConvOp, bool SupportStrides = (ConvOp != cutlass::conv::Operator::kDgrad)>
46
+ std::vector<cutlass::conv::ConvProblemShape<ConvOp, SpatialDim>>
47
+ inline
48
+ get_conv_problem_vector();
49
+
50
+ /////////////////////////////////////////////////////////////////////////////////////////////////
51
+ // Fprop
52
+ /////////////////////////////////////////////////////////////////////////////////////////////////
53
+
54
+ // Specialization for 1D fprop problems
55
+ template<>
56
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 1>> inline
57
+ get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() {
58
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 1>;
59
+ std::vector<ProblemShape> problem_shapes;
60
+ problem_shapes.push_back({
61
+ cutlass::conv::Mode::kCrossCorrelation,
62
+ {1, 8, 64}, // nwc
63
+ {64, 1, 64}, // ksc
64
+ {0}, // padding lower (pad_w)
65
+ {0}, // padding upper (pad_w)
66
+ {1}, // stride (stride_w)
67
+ {1}, // dilation (dilation_w)
68
+ 1 // group
69
+ });
70
+ // non-packed input strides.
71
+ problem_shapes.push_back({
72
+ cutlass::conv::Mode::kCrossCorrelation,
73
+ {1, 8, 64}, // nwc
74
+ {800, 80, 1}, // stride (nwc)
75
+ {64, 1, 64}, // ksc
76
+ {64, 64, 1}, // stride (ksc)
77
+ {0}, // padding lower (pad_w)
78
+ {0}, // padding upper (pad_w)
79
+ {1}, // stride (stride_w)
80
+ {1}, // dilation (dilation_w)
81
+ 1 // group
82
+ });
83
+ // non-packed output strides.
84
+ problem_shapes.push_back({
85
+ cutlass::conv::Mode::kCrossCorrelation,
86
+ {1, 8, 64}, // nwc
87
+ {512, 64, 1}, // stride (nwc)
88
+ {64, 1, 64}, // ksc
89
+ {64, 64, 1}, // stride (ksc)
90
+ {800, 80, 1}, // stride (nqk)
91
+ {0}, // padding lower (pad_w)
92
+ {0}, // padding upper (pad_w)
93
+ {1}, // stride (stride_w)
94
+ {1}, // dilation (dilation_w)
95
+ 1 // group
96
+ });
97
+ // Filter-K = 16 for predication
98
+ problem_shapes.push_back({
99
+ cutlass::conv::Mode::kCrossCorrelation,
100
+ {1, 8, 64},
101
+ {16,1, 64},
102
+ {0},
103
+ {0},
104
+ {1},
105
+ {1},
106
+ 1
107
+ });
108
+ // N = 2 and K = 128 for a larger grid
109
+ problem_shapes.push_back({
110
+ cutlass::conv::Mode::kCrossCorrelation,
111
+ {2, 8, 64},
112
+ {96, 1, 64},
113
+ {0},
114
+ {0},
115
+ {1},
116
+ {1},
117
+ 1
118
+ });
119
+ // N = 7 and K = 256 for a even larger grid
120
+ problem_shapes.push_back({
121
+ cutlass::conv::Mode::kCrossCorrelation,
122
+ {7, 8, 64},
123
+ {256, 1, 64},
124
+ {0},
125
+ {0},
126
+ {1},
127
+ {1},
128
+ 1
129
+ });
130
+ // 3 filter, no padding
131
+ problem_shapes.push_back({
132
+ cutlass::conv::Mode::kCrossCorrelation,
133
+ {2, 8, 64},
134
+ {256, 3, 64},
135
+ {0},
136
+ {0},
137
+ {1},
138
+ {1},
139
+ 1
140
+ });
141
+ // 3 filter, symmetric padding with c % cta_k !=0
142
+ problem_shapes.push_back({
143
+ cutlass::conv::Mode::kCrossCorrelation,
144
+ {2, 8, 32},
145
+ {256, 3, 32},
146
+ {1},
147
+ {1},
148
+ {1},
149
+ {1},
150
+ 1
151
+ });
152
+ // 4 filter, asymmetric padding
153
+ problem_shapes.push_back({
154
+ cutlass::conv::Mode::kCrossCorrelation,
155
+ {2, 8, 64},
156
+ {256, 4, 64},
157
+ {0},
158
+ {1},
159
+ {1},
160
+ {1},
161
+ 1
162
+ });
163
+ // 3 filter, asymmetric padding and tstride of 2
164
+ problem_shapes.push_back({
165
+ cutlass::conv::Mode::kCrossCorrelation,
166
+ {2, 8, 64},
167
+ {256, 3, 64},
168
+ {0},
169
+ {1},
170
+ {2},
171
+ {1},
172
+ 1
173
+ });
174
+ // 3 filter, asymmetric padding and dilation of 2
175
+ problem_shapes.push_back({
176
+ cutlass::conv::Mode::kCrossCorrelation,
177
+ {2, 8, 64},
178
+ {256, 3, 64},
179
+ {0},
180
+ {1},
181
+ {1},
182
+ {2},
183
+ 1
184
+ });
185
+ return problem_shapes;
186
+ }
187
+
188
+ // Specialization for 2D fprop problems
189
+ template<>
190
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 2>> inline
191
+ get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() {
192
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 2>;
193
+ std::vector<ProblemShape> problem_shapes;
194
+ problem_shapes.push_back({
195
+ cutlass::conv::Mode::kCrossCorrelation,
196
+ {1, 8, 8, 64}, // nhwc
197
+ {64, 1, 1, 64}, // krsc
198
+ {0, 0}, // padding lower (pad_h, pad_w)
199
+ {0, 0}, // padding upper (pad_h, pad_w)
200
+ {1, 1}, // stride (stride_h, stride_w)
201
+ {1, 1}, // dilation (dilation_h, dilation_w)
202
+ 1 // group
203
+ });
204
+ // non-packed input strides.
205
+ problem_shapes.push_back({
206
+ cutlass::conv::Mode::kCrossCorrelation,
207
+ {1, 8, 8, 64}, // nhwc
208
+ {8000, 800, 80, 1}, // stride (nhwc)
209
+ {64, 1, 1, 64}, // krsc
210
+ {64, 64, 64, 1}, // stride (krsc)
211
+ {0, 0}, // padding lower (pad_h, pad_w)
212
+ {0, 0}, // padding upper (pad_h, pad_w)
213
+ {1, 1}, // stride (stride_h, stride_w)
214
+ {1, 1}, // dilation (dilation_h, dilation_w)
215
+ 1 // group
216
+ });
217
+ // non-packed output strides.
218
+ problem_shapes.push_back({
219
+ cutlass::conv::Mode::kCrossCorrelation,
220
+ {1, 8, 8, 64}, // nhwc
221
+ {4096, 512, 64, 1}, // stride (nhwc)
222
+ {64, 1, 1, 64}, // krsc
223
+ {64, 64, 64, 1}, // stride (krsc)
224
+ {8000, 800, 80, 1}, // stride (npqk)
225
+ {0, 0}, // padding lower (pad_h, pad_w)
226
+ {0, 0}, // padding upper (pad_h, pad_w)
227
+ {1, 1}, // stride (stride_h, stride_w)
228
+ {1, 1}, // dilation (dilation_h, dilation_w)
229
+ 1 // group
230
+ });
231
+ // Filter-K = 16 for predication
232
+ problem_shapes.push_back({
233
+ cutlass::conv::Mode::kCrossCorrelation,
234
+ {1, 8, 8, 64},
235
+ {16, 1, 1, 64},
236
+ {0, 0},
237
+ {0, 0},
238
+ {1, 1},
239
+ {1, 1},
240
+ 1
241
+ });
242
+ // N = 2 and K = 128 for a larger grid
243
+ problem_shapes.push_back({
244
+ cutlass::conv::Mode::kCrossCorrelation,
245
+ {2, 8, 8, 64},
246
+ {96, 1, 1, 64},
247
+ {0, 0},
248
+ {0, 0},
249
+ {1, 1},
250
+ {1, 1},
251
+ 1
252
+ });
253
+ // N = 7 and K = 256 for a even larger grid
254
+ problem_shapes.push_back({
255
+ cutlass::conv::Mode::kCrossCorrelation,
256
+ {7, 8, 8, 64},
257
+ {256, 1, 1, 64},
258
+ {0, 0},
259
+ {0, 0},
260
+ {1, 1},
261
+ {1, 1},
262
+ 1
263
+ });
264
+ // 3x3 filter, no padding
265
+ problem_shapes.push_back({
266
+ cutlass::conv::Mode::kCrossCorrelation,
267
+ {2, 8, 8, 64},
268
+ {256, 3, 3, 64},
269
+ {0, 0},
270
+ {0, 0},
271
+ {1, 1},
272
+ {1, 1},
273
+ 1
274
+ });
275
+ // 3x3 filter, symmetric padding with c % cta_k !=0
276
+ problem_shapes.push_back({
277
+ cutlass::conv::Mode::kCrossCorrelation,
278
+ {2, 8, 8, 32},
279
+ {256, 3, 3, 32},
280
+ {1, 1},
281
+ {1, 1},
282
+ {1, 1},
283
+ {1, 1},
284
+ 1
285
+ });
286
+ // 2x5 filter, asymmetric padding 1,2/1,2
287
+ problem_shapes.push_back({
288
+ cutlass::conv::Mode::kCrossCorrelation,
289
+ {2, 8, 8, 64},
290
+ {256, 2, 5, 64},
291
+ {1, 1},
292
+ {2, 2},
293
+ {1, 1},
294
+ {1, 1},
295
+ 1
296
+ });
297
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride
298
+ problem_shapes.push_back({
299
+ cutlass::conv::Mode::kCrossCorrelation,
300
+ {2, 7, 7, 64},
301
+ {256, 2, 5, 64},
302
+ {1, 1},
303
+ {0, 0},
304
+ {2, 3},
305
+ {1, 1},
306
+ 1
307
+ });
308
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
309
+ problem_shapes.push_back({
310
+ cutlass::conv::Mode::kCrossCorrelation,
311
+ {2, 16, 16, 64},
312
+ {256, 2, 5, 64},
313
+ {1, 1},
314
+ {0, 0},
315
+ {1, 1},
316
+ {2, 3},
317
+ 1
318
+ });
319
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation
320
+ problem_shapes.push_back({
321
+ cutlass::conv::Mode::kCrossCorrelation,
322
+ {2, 16, 15, 64},
323
+ {256, 2, 5, 64},
324
+ {1, 1},
325
+ {0, 0},
326
+ {2, 3},
327
+ {2, 3},
328
+ 1
329
+ });
330
+ return problem_shapes;
331
+ }
332
+
333
+ // Specialization for 3D fprop problems
334
+ template<>
335
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 3>> inline
336
+ get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() {
337
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kFprop, 3>;
338
+ std::vector<ProblemShape> problem_shapes;
339
+ problem_shapes.push_back({
340
+ cutlass::conv::Mode::kCrossCorrelation,
341
+ {1, 1, 8, 8, 64}, // ndhwc
342
+ {64, 1, 1, 1, 64}, // ktrsc
343
+ {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
344
+ {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
345
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
346
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
347
+ 1 // group
348
+ });
349
+ // non-packed input output strides.
350
+ problem_shapes.push_back({
351
+ cutlass::conv::Mode::kCrossCorrelation,
352
+ {1, 1, 8, 8, 64}, // ndhwc
353
+ {8000, 8000, 800, 80, 1}, // stride (ndhwc)
354
+ {64, 1, 1, 1, 64}, // ktrsc
355
+ {64, 64, 64, 64, 1}, // stride (ktrsc)
356
+ {8000, 8000, 800, 80, 1}, // stride (nzpqk)
357
+ {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
358
+ {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
359
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
360
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
361
+ 1 // group
362
+ });
363
+ // Filter-K = 16 for predication
364
+ problem_shapes.push_back({
365
+ cutlass::conv::Mode::kCrossCorrelation,
366
+ {1, 1, 8, 8, 64},
367
+ {16, 1, 1, 1, 64},
368
+ {0, 0, 0},
369
+ {0, 0, 0},
370
+ {1, 1, 1},
371
+ {1, 1, 1},
372
+ 1
373
+ });
374
+ // N = 7 and K = 256 for a larger grid
375
+ problem_shapes.push_back({
376
+ cutlass::conv::Mode::kCrossCorrelation,
377
+ {2, 1, 8, 8, 64},
378
+ {96, 1, 1, 1, 64},
379
+ {0, 0, 0},
380
+ {0, 0, 0},
381
+ {1, 1, 1},
382
+ {1, 1, 1},
383
+ 1
384
+ });
385
+ // Filter 3x3x3 + no padding
386
+ problem_shapes.push_back({
387
+ cutlass::conv::Mode::kCrossCorrelation,
388
+ {2, 3, 5, 8, 64},
389
+ {96, 3, 3, 3, 64},
390
+ {0, 0, 0},
391
+ {0, 0, 0},
392
+ {1, 1, 1},
393
+ {1, 1, 1},
394
+ 1
395
+ });
396
+ // Filter 3x3x3 + symmetric padding with c % cta_k !=0
397
+ problem_shapes.push_back({
398
+ cutlass::conv::Mode::kCrossCorrelation,
399
+ {2, 3, 5, 8, 32},
400
+ {96, 3, 3, 3, 32},
401
+ {1, 1, 1},
402
+ {1, 1, 1},
403
+ {1, 1, 1},
404
+ {1, 1, 1},
405
+ 1
406
+ });
407
+ // Filter 3x4x5 + symmetric padding 111
408
+ problem_shapes.push_back({
409
+ cutlass::conv::Mode::kCrossCorrelation,
410
+ {2, 3, 5, 8, 64},
411
+ {96, 3, 4, 5, 64},
412
+ {1, 1, 1},
413
+ {1, 1, 1},
414
+ {1, 1, 1},
415
+ {1, 1, 1},
416
+ 1
417
+ });
418
+ // Filter 3x4x5 + asymmetric padding 102/010
419
+ problem_shapes.push_back({
420
+ cutlass::conv::Mode::kCrossCorrelation,
421
+ {2, 3, 5, 8, 64},
422
+ {96, 3, 4, 5, 64},
423
+ {1, 0, 1},
424
+ {0, 2, 0},
425
+ {1, 1, 1},
426
+ {1, 1, 1},
427
+ 1
428
+ });
429
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ stride
430
+ problem_shapes.push_back({
431
+ cutlass::conv::Mode::kCrossCorrelation,
432
+ {2, 16, 10, 16, 64},
433
+ {96, 3, 4, 5, 64},
434
+ {1, 0, 1},
435
+ {0, 2, 0},
436
+ {2, 2, 3},
437
+ {1, 1, 1},
438
+ 1
439
+ });
440
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
441
+ problem_shapes.push_back({
442
+ cutlass::conv::Mode::kCrossCorrelation,
443
+ {2, 16, 10, 16, 64},
444
+ {96, 3, 4, 5, 64},
445
+ {1, 0, 1},
446
+ {0, 2, 0},
447
+ {1, 1, 1},
448
+ {2, 2, 3},
449
+ 1
450
+ });
451
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation
452
+ problem_shapes.push_back({
453
+ cutlass::conv::Mode::kCrossCorrelation,
454
+ {2, 16, 10, 16, 64},
455
+ {96, 3, 4, 5, 64},
456
+ {1, 0, 1},
457
+ {0, 2, 0},
458
+ {2, 2, 3},
459
+ {2, 2, 3},
460
+ 1
461
+ });
462
+ return problem_shapes;
463
+ }
464
+
465
+
466
+ /////////////////////////////////////////////////////////////////////////////////////////////////
467
+ // Wgrad
468
+ /////////////////////////////////////////////////////////////////////////////////////////////////
469
+
470
+ // Specialization for 1D wgrad problems
471
+ template<>
472
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 1>> inline
473
+ get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() {
474
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 1>;
475
+ std::vector<ProblemShape> problem_shapes;
476
+ problem_shapes.push_back({
477
+ cutlass::conv::Mode::kCrossCorrelation,
478
+ {1, 8, 64}, // nwc
479
+ {64, 1, 64}, // ksc
480
+ {0}, // padding lower (pad_w)
481
+ {0}, // padding upper (pad_w)
482
+ {1}, // stride (stride_w)
483
+ {1}, // dilation (dilation_w)
484
+ 1 // group
485
+ });
486
+ // Filter-K = 16 for predication
487
+ problem_shapes.push_back({
488
+ cutlass::conv::Mode::kCrossCorrelation,
489
+ {1, 8, 64},
490
+ {16,1, 64},
491
+ {0},
492
+ {0},
493
+ {1},
494
+ {1},
495
+ 1
496
+ });
497
+ // N = 2 and K = 128 for a larger grid
498
+ problem_shapes.push_back({
499
+ cutlass::conv::Mode::kCrossCorrelation,
500
+ {2, 8, 64},
501
+ {96, 1, 64},
502
+ {0},
503
+ {0},
504
+ {1},
505
+ {1},
506
+ 1
507
+ });
508
+ // N = 7 and K = 256 for a even larger grid
509
+ problem_shapes.push_back({
510
+ cutlass::conv::Mode::kCrossCorrelation,
511
+ {7, 8, 64},
512
+ {256, 1, 64},
513
+ {0},
514
+ {0},
515
+ {1},
516
+ {1},
517
+ 1
518
+ });
519
+ // 3 filter, no padding
520
+ problem_shapes.push_back({
521
+ cutlass::conv::Mode::kCrossCorrelation,
522
+ {2, 8, 32},
523
+ {256, 3, 32},
524
+ {0},
525
+ {0},
526
+ {1},
527
+ {1},
528
+ 1
529
+ });
530
+ // 3 filter, symmetric padding
531
+ problem_shapes.push_back({
532
+ cutlass::conv::Mode::kCrossCorrelation,
533
+ {2, 8, 32},
534
+ {256, 3, 32},
535
+ {1},
536
+ {1},
537
+ {1},
538
+ {1},
539
+ 1
540
+ });
541
+ // 4 filter, asymmetric padding
542
+ problem_shapes.push_back({
543
+ cutlass::conv::Mode::kCrossCorrelation,
544
+ {2, 8, 32},
545
+ {256, 4, 32},
546
+ {0},
547
+ {1},
548
+ {1},
549
+ {1},
550
+ 1
551
+ });
552
+ // 3 filter, asymmetric padding and tstride of 2
553
+ problem_shapes.push_back({
554
+ cutlass::conv::Mode::kCrossCorrelation,
555
+ {2, 8, 32},
556
+ {256, 3, 32},
557
+ {0},
558
+ {1},
559
+ {2},
560
+ {1},
561
+ 1
562
+ });
563
+ // 3 filter, asymmetric padding and dilation of 2
564
+ problem_shapes.push_back({
565
+ cutlass::conv::Mode::kCrossCorrelation,
566
+ {2, 8, 32},
567
+ {256, 3, 32},
568
+ {0},
569
+ {1},
570
+ {1},
571
+ {2},
572
+ 1
573
+ });
574
+ // To test streamk, equals to gemm-MxNxK size 128x640x2048
575
+ problem_shapes.push_back({
576
+ cutlass::conv::Mode::kCrossCorrelation,
577
+ {2, 1024, 128},
578
+ {640, 1, 128},
579
+ {0},
580
+ {0},
581
+ {1},
582
+ {1},
583
+ 1
584
+ });
585
+ // To test streamk, equals to gemm-MxNxK size 128x640x2080
586
+ problem_shapes.push_back({
587
+ cutlass::conv::Mode::kCrossCorrelation,
588
+ {2, 1040, 128},
589
+ {640, 1, 128},
590
+ {0},
591
+ {0},
592
+ {1},
593
+ {1},
594
+ 1
595
+ });
596
+ return problem_shapes;
597
+ }
598
+
599
+ // Specialization for 2D wgrad problems
600
+ template<>
601
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 2>> inline
602
+ get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() {
603
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 2>;
604
+ std::vector<ProblemShape> problem_shapes;
605
+ problem_shapes.push_back({
606
+ cutlass::conv::Mode::kCrossCorrelation,
607
+ {1, 8, 8, 64}, // nhwc
608
+ {64, 1, 1, 64}, // krsc
609
+ {0, 0}, // padding lower (pad_h, pad_w)
610
+ {0, 0}, // padding upper (pad_h, pad_w)
611
+ {1, 1}, // stride (stride_h, stride_w)
612
+ {1, 1}, // dilation (dilation_h, dilation_w)
613
+ 1 // group
614
+ });
615
+ // Filter-K = 16 for predication
616
+ problem_shapes.push_back({
617
+ cutlass::conv::Mode::kCrossCorrelation,
618
+ {1, 8, 8, 64},
619
+ {16, 1, 1, 64},
620
+ {0, 0},
621
+ {0, 0},
622
+ {1, 1},
623
+ {1, 1},
624
+ 1
625
+ });
626
+ // N = 2 and K = 128 for a larger grid
627
+ problem_shapes.push_back({
628
+ cutlass::conv::Mode::kCrossCorrelation,
629
+ {2, 8, 8, 64},
630
+ {96, 1, 1, 64},
631
+ {0, 0},
632
+ {0, 0},
633
+ {1, 1},
634
+ {1, 1},
635
+ 1
636
+ });
637
+ // N = 7 and K = 256 for a even larger grid
638
+ problem_shapes.push_back({
639
+ cutlass::conv::Mode::kCrossCorrelation,
640
+ {7, 8, 8, 64},
641
+ {256, 1, 1, 64},
642
+ {0, 0},
643
+ {0, 0},
644
+ {1, 1},
645
+ {1, 1},
646
+ 1
647
+ });
648
+ // 3x3 filter, no padding
649
+ problem_shapes.push_back({
650
+ cutlass::conv::Mode::kCrossCorrelation,
651
+ {2, 8, 8, 32},
652
+ {256, 3, 3, 32},
653
+ {0, 0},
654
+ {0, 0},
655
+ {1, 1},
656
+ {1, 1},
657
+ 1
658
+ });
659
+ // 3x3 filter, symmetric padding
660
+ problem_shapes.push_back({
661
+ cutlass::conv::Mode::kCrossCorrelation,
662
+ {2, 8, 8, 32},
663
+ {256, 3, 3, 32},
664
+ {1, 1},
665
+ {1, 1},
666
+ {1, 1},
667
+ {1, 1},
668
+ 1
669
+ });
670
+ // 2x5 filter, asymmetric padding 1,0/1,0
671
+ problem_shapes.push_back({
672
+ cutlass::conv::Mode::kCrossCorrelation,
673
+ {2, 8, 8, 32},
674
+ {256, 2, 5, 32},
675
+ {1, 1},
676
+ {0, 0},
677
+ {1, 1},
678
+ {1, 1},
679
+ 1
680
+ });
681
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride
682
+ problem_shapes.push_back({
683
+ cutlass::conv::Mode::kCrossCorrelation,
684
+ {2, 15, 16, 32},
685
+ {256, 2, 5, 32},
686
+ {1, 1},
687
+ {0, 0},
688
+ {2, 3},
689
+ {1, 1},
690
+ 1
691
+ });
692
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
693
+ problem_shapes.push_back({
694
+ cutlass::conv::Mode::kCrossCorrelation,
695
+ {2, 16, 16, 32},
696
+ {256, 2, 5, 32},
697
+ {1, 1},
698
+ {0, 0},
699
+ {1, 1},
700
+ {2, 3},
701
+ 1
702
+ });
703
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation
704
+ problem_shapes.push_back({
705
+ cutlass::conv::Mode::kCrossCorrelation,
706
+ {2, 16, 15, 32},
707
+ {256, 2, 5, 32},
708
+ {1, 1},
709
+ {0, 0},
710
+ {2, 3},
711
+ {2, 3},
712
+ 1
713
+ });
714
+ // To test streamk, equals to gemm-MxNxK size 128x640x2048
715
+ problem_shapes.push_back({
716
+ cutlass::conv::Mode::kCrossCorrelation,
717
+ {2, 64, 16, 128},
718
+ {640, 1, 1, 128},
719
+ {0, 0},
720
+ {0, 0},
721
+ {1, 1},
722
+ {1, 1},
723
+ 1
724
+ });
725
+ // To test streamk, equals to gemm-MxNxK size 128x640x2080
726
+ problem_shapes.push_back({
727
+ cutlass::conv::Mode::kCrossCorrelation,
728
+ {2, 65, 16, 128},
729
+ {640, 1, 1, 128},
730
+ {0, 0},
731
+ {0, 0},
732
+ {1, 1},
733
+ {1, 1},
734
+ 1
735
+ });
736
+ return problem_shapes;
737
+ }
738
+
739
+ // Specialization for 3D wgrad problems
740
+ template<>
741
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>> inline
742
+ get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() {
743
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>;
744
+ std::vector<ProblemShape> problem_shapes;
745
+ problem_shapes.push_back({
746
+ cutlass::conv::Mode::kCrossCorrelation,
747
+ {2, 1, 8, 8, 64}, // ndhwc
748
+ {64, 1, 1, 1, 64}, // ktrsc
749
+ {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
750
+ {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
751
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
752
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
753
+ 1 // group
754
+ });
755
+ // Filter 3x3x3 + no padding
756
+ problem_shapes.push_back({
757
+ cutlass::conv::Mode::kCrossCorrelation,
758
+ {2, 3, 5, 8, 32},
759
+ {96, 3, 3, 3, 32},
760
+ {0, 0, 0},
761
+ {0, 0, 0},
762
+ {1, 1, 1},
763
+ {1, 1, 1},
764
+ 1
765
+ });
766
+ // Filter 3x4x5 + asymmetric padding 102/010
767
+ problem_shapes.push_back({
768
+ cutlass::conv::Mode::kCrossCorrelation,
769
+ {2, 3, 5, 8, 32},
770
+ {96, 3, 4, 5, 32},
771
+ {1, 0, 1},
772
+ {0, 2, 0},
773
+ {1, 1, 1},
774
+ {1, 1, 1},
775
+ 1
776
+ });
777
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ stride
778
+ problem_shapes.push_back({
779
+ cutlass::conv::Mode::kCrossCorrelation,
780
+ {2, 16, 10, 16, 32},
781
+ {96, 3, 4, 5, 32},
782
+ {1, 0, 1},
783
+ {0, 2, 0},
784
+ {2, 2, 3},
785
+ {1, 1, 1},
786
+ 1
787
+ });
788
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
789
+ problem_shapes.push_back({
790
+ cutlass::conv::Mode::kCrossCorrelation,
791
+ {2, 16, 10, 16, 32},
792
+ {96, 3, 4, 5, 32},
793
+ {1, 0, 1},
794
+ {0, 2, 0},
795
+ {1, 1, 1},
796
+ {2, 2, 3},
797
+ 1
798
+ });
799
+ // To test streamk, equals to gemm-MxNxK size 128x640x2048
800
+ problem_shapes.push_back({
801
+ cutlass::conv::Mode::kCrossCorrelation,
802
+ {2, 1, 64, 16, 128},
803
+ {640, 1, 1, 1, 128},
804
+ {0, 0, 0},
805
+ {0, 0, 0},
806
+ {1, 1, 1},
807
+ {1, 1, 1},
808
+ 1
809
+ });
810
+ // To test streamk, equals to gemm-MxNxK size 128x640x2080
811
+ problem_shapes.push_back({
812
+ cutlass::conv::Mode::kCrossCorrelation,
813
+ {2, 1, 65, 16, 128},
814
+ {640, 1, 1, 1, 128},
815
+ {0, 0, 0},
816
+ {0, 0, 0},
817
+ {1, 1, 1},
818
+ {1, 1, 1},
819
+ 1
820
+ });
821
+ return problem_shapes;
822
+ }
823
+
824
+ /////////////////////////////////////////////////////////////////////////////////////////////////
825
+ // Grouped Wgrad
826
+ /////////////////////////////////////////////////////////////////////////////////////////////////
827
+
828
+ // Get problem size vectors for group conv problems
829
+ template<int SpatialDim, cutlass::conv::Operator ConvOp>
830
+ std::vector<cutlass::conv::ConvProblemShape<ConvOp, SpatialDim>>
831
+ inline
832
+ get_grouped_conv_problem_vector(int GroupsPerTile);
833
+
834
+ // Specialization for 3D wgrad problems
835
+ template<>
836
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>> inline
837
+ get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) {
838
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kWgrad, 3>;
839
+ std::vector<ProblemShape> problem_shapes;
840
+
841
+ if (GroupsPerTile == 1) {
842
+ // channel_per_group == 64
843
+ problem_shapes.push_back({
844
+ cutlass::conv::Mode::kCrossCorrelation,
845
+ {1, 1, 16, 16, 2048}, // ndhwc
846
+ {2048, 1, 3, 3, 64}, // ktrsc
847
+ {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
848
+ {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
849
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
850
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
851
+ 32 // groups
852
+ });
853
+ }
854
+ else if (GroupsPerTile == 2) {
855
+ // channel_per_group == 32
856
+ problem_shapes.push_back({
857
+ cutlass::conv::Mode::kCrossCorrelation,
858
+ {1, 1, 16, 16, 1024}, // ndhwc
859
+ {1024, 1, 3, 3, 32}, // ktrsc
860
+ {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
861
+ {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
862
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
863
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
864
+ 32 // groups
865
+ });
866
+ }
867
+ else if (GroupsPerTile == 4) {
868
+ // channel_per_group == 16
869
+ problem_shapes.push_back({
870
+ cutlass::conv::Mode::kCrossCorrelation,
871
+ {1, 1, 16, 16, 512}, // ndhwc
872
+ {512, 1, 3, 3, 16}, // ktrsc
873
+ {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
874
+ {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
875
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
876
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
877
+ 32 // groups
878
+ });
879
+ }
880
+ else if (GroupsPerTile == 8) {
881
+ // channel_per_group == 8
882
+ problem_shapes.push_back({
883
+ cutlass::conv::Mode::kCrossCorrelation,
884
+ {1, 1, 16, 16, 256}, // ndhwc
885
+ {256, 1, 3, 3, 8}, // ktrsc
886
+ {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w)
887
+ {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w)
888
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
889
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
890
+ 32 // groups
891
+ });
892
+ }
893
+ return problem_shapes;
894
+ }
895
+
896
+ /////////////////////////////////////////////////////////////////////////////////////////////////
897
+ // Unit Stride Dgrad
898
+ /////////////////////////////////////////////////////////////////////////////////////////////////
899
+
900
+ // Specialization for 1D dgrad problems
901
+ template<>
902
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>> inline
903
+ get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() {
904
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>;
905
+ std::vector<ProblemShape> problem_shapes;
906
+ problem_shapes.push_back({
907
+ cutlass::conv::Mode::kCrossCorrelation,
908
+ {1, 8, 64}, // nqk
909
+ {64, 1, 64}, // ksc
910
+ {0}, // padding lower (pad_w)
911
+ {0}, // padding upper (pad_w)
912
+ {1}, // stride (stride_w)
913
+ {1}, // dilation (dilation_w)
914
+ 1 // group
915
+ });
916
+ // non-packed input strides.
917
+ problem_shapes.push_back({
918
+ cutlass::conv::Mode::kCrossCorrelation,
919
+ {1, 8, 64}, // nqk
920
+ {800, 80, 1}, // stride (nqk)
921
+ {64, 1, 64}, // ksc
922
+ {64, 64, 1}, // stride (ksc)
923
+ {0}, // padding lower (pad_w)
924
+ {0}, // padding upper (pad_w)
925
+ {1}, // stride (stride_w)
926
+ {1}, // dilation (dilation_w)
927
+ 1 // group
928
+ });
929
+ // non-packed output strides.
930
+ problem_shapes.push_back({
931
+ cutlass::conv::Mode::kCrossCorrelation,
932
+ {1, 8, 64}, // nqk
933
+ {512, 64, 1}, // stride (nqk)
934
+ {64, 1, 64}, // ksc
935
+ {64, 64, 1}, // stride (ksc)
936
+ {800, 80, 1}, // stride (nwc)
937
+ {0}, // padding lower (pad_w)
938
+ {0}, // padding upper (pad_w)
939
+ {1}, // stride (stride_w)
940
+ {1}, // dilation (dilation_w)
941
+ 1 // group
942
+ });
943
+ // Filter-K = 16 for predication
944
+ problem_shapes.push_back({
945
+ cutlass::conv::Mode::kCrossCorrelation,
946
+ {1, 8, 16},
947
+ {64, 1, 16},
948
+ {0},
949
+ {0},
950
+ {1},
951
+ {1},
952
+ 1
953
+ });
954
+ // N = 2 and K = 128 for a larger grid
955
+ problem_shapes.push_back({
956
+ cutlass::conv::Mode::kCrossCorrelation,
957
+ {2, 8, 96},
958
+ {64, 1, 96},
959
+ {0},
960
+ {0},
961
+ {1},
962
+ {1},
963
+ 1
964
+ });
965
+ // N = 7 and K = 256 for a even larger grid
966
+ problem_shapes.push_back({
967
+ cutlass::conv::Mode::kCrossCorrelation,
968
+ {7, 8, 256},
969
+ {64, 1, 256},
970
+ {0},
971
+ {0},
972
+ {1},
973
+ {1},
974
+ 1
975
+ });
976
+ // 3 filter, no padding
977
+ problem_shapes.push_back({
978
+ cutlass::conv::Mode::kCrossCorrelation,
979
+ {2, 8, 256},
980
+ {64, 3, 256},
981
+ {0},
982
+ {0},
983
+ {1},
984
+ {1},
985
+ 1
986
+ });
987
+ // 3 filter, symmetric padding with k % cta_k !=0
988
+ problem_shapes.push_back({
989
+ cutlass::conv::Mode::kCrossCorrelation,
990
+ {2, 8, 256},
991
+ {32, 3, 256},
992
+ {1},
993
+ {1},
994
+ {1},
995
+ {1},
996
+ 1
997
+ });
998
+ // 4 filter, asymmetric padding
999
+ problem_shapes.push_back({
1000
+ cutlass::conv::Mode::kCrossCorrelation,
1001
+ {2, 8, 256},
1002
+ {64, 4, 256},
1003
+ {0},
1004
+ {1},
1005
+ {1},
1006
+ {1},
1007
+ 1
1008
+ });
1009
+ // 3 filter, asymmetric padding and dilation of 2
1010
+ problem_shapes.push_back({
1011
+ cutlass::conv::Mode::kCrossCorrelation,
1012
+ {2, 16, 64},
1013
+ {256, 3, 64},
1014
+ {0},
1015
+ {1},
1016
+ {1},
1017
+ {2},
1018
+ 1
1019
+ });
1020
+ return problem_shapes;
1021
+ }
1022
+
1023
+ // Specialization for 2D dgrad problems
1024
+ template<>
1025
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>> inline
1026
+ get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() {
1027
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>;
1028
+ std::vector<ProblemShape> problem_shapes;
1029
+ problem_shapes.push_back({
1030
+ cutlass::conv::Mode::kCrossCorrelation,
1031
+ {1, 8, 8, 64}, // npqk
1032
+ {64, 1, 1, 64}, // krsc
1033
+ {0, 0}, // padding lower (pad_h, pad_w)
1034
+ {0, 0}, // padding upper (pad_h, pad_w)
1035
+ {1, 1}, // stride (stride_h, stride_w)
1036
+ {1, 1}, // dilation (dilation_h, dilation_w)
1037
+ 1 // group
1038
+ });
1039
+ // non-packed input strides.
1040
+ problem_shapes.push_back({
1041
+ cutlass::conv::Mode::kCrossCorrelation,
1042
+ {1, 8, 8, 64}, // npqk
1043
+ {8000, 800, 80, 1}, // stride (npqk)
1044
+ {64, 1, 1, 64}, // krsc
1045
+ {64, 64, 64, 1}, // stride (krsc)
1046
+ {0, 0}, // padding lower (pad_h, pad_w)
1047
+ {0, 0}, // padding upper (pad_h, pad_w)
1048
+ {1, 1}, // stride (stride_h, stride_w)
1049
+ {1, 1}, // dilation (dilation_h, dilation_w)
1050
+ 1 // group
1051
+ });
1052
+ // non-packed output strides.
1053
+ problem_shapes.push_back({
1054
+ cutlass::conv::Mode::kCrossCorrelation,
1055
+ {1, 8, 8, 64}, // npqk
1056
+ {4096, 512, 64, 1}, // stride (npqk)
1057
+ {64, 1, 1, 64}, // krsc
1058
+ {64, 64, 64, 1}, // stride (krsc)
1059
+ {8000, 800, 80, 1}, // stride (nhwc)
1060
+ {0, 0}, // padding lower (pad_h, pad_w)
1061
+ {0, 0}, // padding upper (pad_h, pad_w)
1062
+ {1, 1}, // stride (stride_h, stride_w)
1063
+ {1, 1}, // dilation (dilation_h, dilation_w)
1064
+ 1 // group
1065
+ });
1066
+ // Filter-K = 16 for predication
1067
+ problem_shapes.push_back({
1068
+ cutlass::conv::Mode::kCrossCorrelation,
1069
+ {1, 8, 8, 16},
1070
+ {64, 1, 1, 16},
1071
+ {0, 0},
1072
+ {0, 0},
1073
+ {1, 1},
1074
+ {1, 1},
1075
+ 1
1076
+ });
1077
+ // N = 2 and K = 128 for a larger grid
1078
+ problem_shapes.push_back({
1079
+ cutlass::conv::Mode::kCrossCorrelation,
1080
+ {2, 8, 8, 96},
1081
+ {64, 1, 1, 96},
1082
+ {0, 0},
1083
+ {0, 0},
1084
+ {1, 1},
1085
+ {1, 1},
1086
+ 1
1087
+ });
1088
+ // N = 7 and K = 256 for a even larger grid
1089
+ problem_shapes.push_back({
1090
+ cutlass::conv::Mode::kCrossCorrelation,
1091
+ {7, 8, 8, 256},
1092
+ {64, 1, 1, 256},
1093
+ {0, 0},
1094
+ {0, 0},
1095
+ {1, 1},
1096
+ {1, 1},
1097
+ 1
1098
+ });
1099
+ // 3x3 filter, no padding
1100
+ problem_shapes.push_back({
1101
+ cutlass::conv::Mode::kCrossCorrelation,
1102
+ {2, 8, 8, 256},
1103
+ {64, 3, 3, 256},
1104
+ {0, 0},
1105
+ {0, 0},
1106
+ {1, 1},
1107
+ {1, 1},
1108
+ 1
1109
+ });
1110
+ // 3x3 filter, symmetric padding with k % cta_k !=0
1111
+ problem_shapes.push_back({
1112
+ cutlass::conv::Mode::kCrossCorrelation,
1113
+ {2, 8, 8, 256},
1114
+ {32, 3, 3, 256},
1115
+ {1, 1},
1116
+ {1, 1},
1117
+ {1, 1},
1118
+ {1, 1},
1119
+ 1
1120
+ });
1121
+ // 2x5 filter, asymmetric padding 1,0/1,0
1122
+ problem_shapes.push_back({
1123
+ cutlass::conv::Mode::kCrossCorrelation,
1124
+ {2, 8, 8, 256},
1125
+ {64, 2, 5, 256},
1126
+ {1, 1},
1127
+ {0, 0},
1128
+ {1, 1},
1129
+ {1, 1},
1130
+ 1
1131
+ });
1132
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
1133
+ problem_shapes.push_back({
1134
+ cutlass::conv::Mode::kCrossCorrelation,
1135
+ {2, 16, 16, 64},
1136
+ {256, 2, 5, 64},
1137
+ {1, 1},
1138
+ {0, 0},
1139
+ {1, 1},
1140
+ {2, 3},
1141
+ 1
1142
+ });
1143
+ return problem_shapes;
1144
+ }
1145
+
1146
+ // Specialization for 3D dgrad problems
1147
+ template<>
1148
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>> inline
1149
+ get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() {
1150
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>;
1151
+ std::vector<ProblemShape> problem_shapes;
1152
+ // Filter-K = 16 for predication
1153
+ problem_shapes.push_back({
1154
+ cutlass::conv::Mode::kCrossCorrelation,
1155
+ {1, 1, 8, 8, 16},
1156
+ {64, 1, 1, 1, 16},
1157
+ {0, 0, 0},
1158
+ {0, 0, 0},
1159
+ {1, 1, 1},
1160
+ {1, 1, 1},
1161
+ 1
1162
+ });
1163
+ // non-packed input output strides.
1164
+ problem_shapes.push_back({
1165
+ cutlass::conv::Mode::kCrossCorrelation,
1166
+ {1, 1, 8, 8, 64}, // nzpqk
1167
+ {8000, 8000, 800, 80, 1}, // stride (nzpqk)
1168
+ {64, 1, 1, 1, 64}, // ktrsc
1169
+ {64, 64, 64, 64, 1}, // stride (ktrsc)
1170
+ {8000, 8000, 800, 80, 1}, // stride (ndhwc)
1171
+ {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w)
1172
+ {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w)
1173
+ {1, 1, 1}, // stride (stride_d, stride_h, stride_w)
1174
+ {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w)
1175
+ 1 // group
1176
+ });
1177
+ // N = 7 and K = 256 for a larger grid
1178
+ problem_shapes.push_back({
1179
+ cutlass::conv::Mode::kCrossCorrelation,
1180
+ {2, 1, 8, 8, 96},
1181
+ {64, 1, 1, 1, 96},
1182
+ {0, 0, 0},
1183
+ {0, 0, 0},
1184
+ {1, 1, 1},
1185
+ {1, 1, 1},
1186
+ 1
1187
+ });
1188
+ // Filter 3x4x5 + symmetric padding 111
1189
+ problem_shapes.push_back({
1190
+ cutlass::conv::Mode::kCrossCorrelation,
1191
+ {2, 3, 5, 8, 96},
1192
+ {64, 3, 4, 5, 96},
1193
+ {1, 1, 1},
1194
+ {1, 1, 1},
1195
+ {1, 1, 1},
1196
+ {1, 1, 1},
1197
+ 1
1198
+ });
1199
+ // Filter 3x4x5 + asymmetric padding 102/010
1200
+ problem_shapes.push_back({
1201
+ cutlass::conv::Mode::kCrossCorrelation,
1202
+ {2, 3, 5, 8, 96},
1203
+ {64, 3, 4, 5, 96},
1204
+ {1, 0, 1},
1205
+ {0, 2, 0},
1206
+ {1, 1, 1},
1207
+ {1, 1, 1},
1208
+ 1
1209
+ });
1210
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
1211
+ problem_shapes.push_back({
1212
+ cutlass::conv::Mode::kCrossCorrelation,
1213
+ {2, 16, 10, 16, 64},
1214
+ {64, 3, 4, 5, 96},
1215
+ {1, 0, 1},
1216
+ {0, 2, 0},
1217
+ {1, 1, 1},
1218
+ {2, 2, 3},
1219
+ 1
1220
+ });
1221
+ return problem_shapes;
1222
+ }
1223
+
1224
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1225
+ // Strided Dgrad
1226
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1227
+
1228
+ // Specialization for 1D dgrad problems
1229
+ template<>
1230
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>> inline
1231
+ get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() {
1232
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 1>;
1233
+ std::vector<ProblemShape> problem_shapes;
1234
+ // Test TMA truncation
1235
+ problem_shapes.push_back({
1236
+ cutlass::conv::Mode::kCrossCorrelation,
1237
+ {1, 512, 64}, // nqk
1238
+ {64, 1, 64}, // ksc
1239
+ {0}, // padding lower (pad_w)
1240
+ {0}, // padding upper (pad_w)
1241
+ {2}, // stride (stride_w)
1242
+ {1}, // dilation (dilation_w)
1243
+ 1 // group
1244
+ });
1245
+ problem_shapes.push_back({
1246
+ cutlass::conv::Mode::kCrossCorrelation,
1247
+ {1, 1024, 64}, // nqk
1248
+ {64, 1, 64}, // ksc
1249
+ {0}, // padding lower (pad_w)
1250
+ {0}, // padding upper (pad_w)
1251
+ {4}, // stride (stride_w)
1252
+ {1}, // dilation (dilation_w)
1253
+ 1 // group
1254
+ });
1255
+ problem_shapes.push_back({
1256
+ cutlass::conv::Mode::kCrossCorrelation,
1257
+ {1, 2048, 64}, // nqk
1258
+ {64, 1, 64}, // ksc
1259
+ {0}, // padding lower (pad_w)
1260
+ {0}, // padding upper (pad_w)
1261
+ {8}, // stride (stride_w)
1262
+ {1}, // dilation (dilation_w)
1263
+ 1 // group
1264
+ });
1265
+ // non-packed input/output strides.
1266
+ // stride divides dilation
1267
+ // asymmetric padding
1268
+ problem_shapes.push_back({
1269
+ cutlass::conv::Mode::kCrossCorrelation,
1270
+ {3, 8, 64}, // nqk
1271
+ {800, 80, 1}, // stride (nqk)
1272
+ {64, 3, 64}, // ksc
1273
+ {64, 64, 1}, // stride (ksc)
1274
+ {800, 80, 1}, // stride (nwc)
1275
+ {0}, // padding lower (pad_w)
1276
+ {1}, // padding upper (pad_w)
1277
+ {2}, // stride (stride_w)
1278
+ {4}, // dilation (dilation_w)
1279
+ 1 // group
1280
+ });
1281
+ // non-packed input/output strides.
1282
+ // dilation divides stride
1283
+ // asymmetric padding
1284
+ problem_shapes.push_back({
1285
+ cutlass::conv::Mode::kCrossCorrelation,
1286
+ {3, 8, 64}, // nqk
1287
+ {800, 80, 1}, // stride (nqk)
1288
+ {64, 3, 64}, // ksc
1289
+ {64, 64, 1}, // stride (ksc)
1290
+ {800, 80, 1}, // stride (nwc)
1291
+ {1}, // padding lower (pad_w)
1292
+ {0}, // padding upper (pad_w)
1293
+ {4}, // stride (stride_w)
1294
+ {2}, // dilation (dilation_w)
1295
+ 1 // group
1296
+ });
1297
+ // non-packed input/output strides.
1298
+ // stride dilation dont divide
1299
+ // asymmetric padding
1300
+ problem_shapes.push_back({
1301
+ cutlass::conv::Mode::kCrossCorrelation,
1302
+ {3, 8, 64}, // nqk
1303
+ {800, 80, 1}, // stride (nqk)
1304
+ {64, 3, 64}, // ksc
1305
+ {64, 64, 1}, // stride (ksc)
1306
+ {800, 80, 1}, // stride (nwc)
1307
+ {1}, // padding lower (pad_w)
1308
+ {2}, // padding upper (pad_w)
1309
+ {2}, // stride (stride_w)
1310
+ {3}, // dilation (dilation_w)
1311
+ 1 // group
1312
+ });
1313
+ return problem_shapes;
1314
+ }
1315
+
1316
+ // Specialization for 2D dgrad problems
1317
+ template<>
1318
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>> inline
1319
+ get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() {
1320
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 2>;
1321
+ std::vector<ProblemShape> problem_shapes;
1322
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
1323
+ // mode 0 stride divides dilation
1324
+ // mode 1 dilation divides stride
1325
+ problem_shapes.push_back({
1326
+ cutlass::conv::Mode::kCrossCorrelation,
1327
+ {3, 16, 16, 64},
1328
+ {256, 2, 5, 64},
1329
+ {1, 0},
1330
+ {0, 1},
1331
+ {2, 4},
1332
+ {4, 2},
1333
+ 1
1334
+ });
1335
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
1336
+ // mode 0 dilation divides stride
1337
+ // mode 1 stride divides dilation
1338
+ problem_shapes.push_back({
1339
+ cutlass::conv::Mode::kCrossCorrelation,
1340
+ {3, 16, 16, 64},
1341
+ {256, 2, 5, 64},
1342
+ {1, 0},
1343
+ {0, 1},
1344
+ {4, 2},
1345
+ {2, 4},
1346
+ 1
1347
+ });
1348
+ // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation
1349
+ // stride dilation dont divide
1350
+ problem_shapes.push_back({
1351
+ cutlass::conv::Mode::kCrossCorrelation,
1352
+ {3, 16, 16, 64},
1353
+ {256, 2, 5, 64},
1354
+ {1, 0},
1355
+ {0, 1},
1356
+ {3, 2},
1357
+ {2, 3},
1358
+ 1
1359
+ });
1360
+ return problem_shapes;
1361
+ }
1362
+
1363
+ // Specialization for 3D dgrad problems
1364
+ template<>
1365
+ std::vector<cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>> inline
1366
+ get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() {
1367
+ using ProblemShape = cutlass::conv::ConvProblemShape<cutlass::conv::Operator::kDgrad, 3>;
1368
+ std::vector<ProblemShape> problem_shapes;
1369
+ // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation
1370
+ problem_shapes.push_back({
1371
+ cutlass::conv::Mode::kCrossCorrelation,
1372
+ {2, 16, 10, 16, 64},
1373
+ {64, 3, 4, 5, 96},
1374
+ {1, 0, 1},
1375
+ {0, 2, 0},
1376
+ {2, 1, 2},
1377
+ {4, 2, 3},
1378
+ 1
1379
+ });
1380
+ return problem_shapes;
1381
+ }
1382
+
1383
+ /////////////////////////////////////////////////////////////////////////////////////////////////
1384
+
1385
+ } // namespace cutlass::test
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+ /*! \file
32
+ \brief Implicit GEMM testbed for 3.x API
33
+ */
34
+ #pragma once
35
+
36
+ #include "cutlass/cutlass.h"
37
+ #include "../../common/cutlass_unit_test.h"
38
+
39
+ #include "cute/tensor.hpp"
40
+ #include "cutlass/kernel_hardware_info.hpp"
41
+ #include "cutlass/conv/convolution.h"
42
+ #include "cutlass/conv/convnd_problem_shape.hpp"
43
+ #include "../test/unit/gemm/device/gemm_testbed_3x.hpp"
44
+
45
+ #include "thrust/universal_vector.h"
46
+ #include "cutlass/util/distribution.h"
47
+ #include "cutlass/util/host_tensor.h"
48
+ #include "cutlass/util/tensor_view_io.h"
49
+ #include "cutlass/util/packed_stride.hpp"
50
+ #include "cutlass/util/reference/host/conv.hpp"
51
+ #include "cutlass/util/reference/host/tensor_fill.h"
52
+ #include "cutlass/util/reference/host/tensor_copy.h"
53
+ #include "cutlass/util/reference/host/tensor_compare.h"
54
+ #include "cutlass/util/reference/host/tensor_norm.h"
55
+ #include "cutlass/util/reference/device/tensor_fill.h"
56
+ #include "cutlass/util/reference/device/tensor_compare.h"
57
+ #include "conv_problem_sizes.hpp"
58
+ #include "../cache_testbed_output.h"
59
+
60
+ #include <iostream>
61
+
62
+ #include "cute/layout.hpp"
63
+ /////////////////////////////////////////////////////////////////////////////////////////////////
64
+
65
+ namespace test::conv::device {
66
+
67
+ /////////////////////////////////////////////////////////////////////////////////////////////////
68
+
69
+ // Initializes a flat device buffer
70
+ template <typename Element>
71
+ static void
72
+ initialize_values(
73
+ thrust::universal_vector<Element>& dst_ptr,
74
+ cutlass::Distribution::Kind dist_kind,
75
+ uint64_t seed) {
76
+ if (cutlass::Distribution::Uniform == dist_kind) {
77
+ int scope;
78
+ int bits = cutlass::sizeof_bits<Element>::value;
79
+
80
+ if (bits <= 8) {
81
+ scope = 2;
82
+ }
83
+ else if (bits == 16) {
84
+ scope = 4;
85
+ }
86
+ else {
87
+ scope = 8;
88
+ }
89
+ cutlass::reference::host::BlockFillRandomUniform(
90
+ dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0);
91
+ }
92
+ else if (cutlass::Distribution::Identity == dist_kind) {
93
+ cutlass::reference::host::BlockFillRandomUniform(
94
+ dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0);
95
+ }
96
+ else if (cutlass::Distribution::Gaussian == dist_kind) {
97
+ cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5);
98
+ }
99
+ else if (cutlass::Distribution::Sequential == dist_kind) {
100
+ cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size());
101
+ }
102
+ else {
103
+ std::cerr << "Invalid distribution kind!\n.";
104
+ exit(1);
105
+ }
106
+ }
107
+
108
+ /////////////////////////////////////////////////////////////////////////////////////////////////
109
+ // utils for sparse or dense conv parameters
110
+
111
+ template <class Conv>
112
+ struct DenseConvParams {
113
+ // Default Kernel data types
114
+ using ElementA = typename Conv::ConvKernel::ElementA;
115
+ using ElementB = typename Conv::ConvKernel::ElementB;
116
+
117
+ static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp;
118
+ static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions;
119
+ using ProblemShape = cutlass::conv::ConvProblemShape<ConvOp, NumSpatialDimensions>;
120
+
121
+ // get the default arguments without sparse data
122
+ auto get_mainloop_arguments(
123
+ [[maybe_unused]] ProblemShape const& problem_shape,
124
+ thrust::universal_vector<ElementA>& tensor_A,
125
+ thrust::universal_vector<ElementB>& tensor_B
126
+ ) {
127
+ auto args = typename Conv::ConvKernel::MainloopArguments {
128
+ tensor_A.data().get(),
129
+ tensor_B.data().get(),
130
+ };
131
+ return args;
132
+ }
133
+ };
134
+
135
+ template <class Conv>
136
+ struct SparseConvParams {
137
+ };
138
+
139
+ /////////////////////////////////////////////////////////////////////////////////////////////////
140
+ template <class Conv, bool isSparseEnabled_ = false>
141
+ struct ConvTestbed {
142
+ // Kernel data types
143
+ using ElementA = typename Conv::ConvKernel::ElementA;
144
+ using ElementB = typename Conv::ConvKernel::ElementB;
145
+ using ElementC = cute::conditional_t<cute::is_void_v<typename Conv::ConvKernel::ElementC>,
146
+ typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>;
147
+ using ElementD = typename Conv::ConvKernel::ElementD;
148
+ using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator;
149
+
150
+ // ConvTest for sparse kernel
151
+ static constexpr bool isSparseEnabled = isSparseEnabled_;
152
+ using ConvParams = cute::conditional_t<isSparseEnabled, SparseConvParams<Conv>, DenseConvParams<Conv>>;
153
+ ConvParams params;
154
+
155
+ //
156
+ // FusionOperation derived types/queries
157
+ //
158
+ using FusionOp = typename Conv::EpilogueOutputOp;
159
+
160
+ // fusion types are potentially void if the fusion is not supported
161
+ // helper so we don't try to construct HostTensor with void type
162
+ template <typename T, typename U = uint8_t>
163
+ using non_void_t = cute::conditional_t<cute::is_void_v<T>, U, T>;
164
+ using ElementScalar = typename FusionOp::ElementScalar;
165
+ using ElementCompute = typename FusionOp::ElementCompute;
166
+ using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias<FusionOp>::type;
167
+ using ElementBias = non_void_t<BiasType>;
168
+ using ActivationType = non_void_t<typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation<FusionOp>::type,
169
+ cutlass::epilogue::thread::Identity<ElementCompute>>;
170
+ static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation<FusionOp>::value;
171
+ using ActivationFunctor = cute::conditional_t<IsActivationEnabled, ActivationType, cutlass::epilogue::thread::Identity<ElementCompute>>;
172
+
173
+ static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias<FusionOp>::value &&
174
+ !cute::is_same_v<BiasType, void>;
175
+ static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled<FusionOp>::value;
176
+
177
+ static constexpr bool DisableSource = cute::is_void_v<typename FusionOp::ElementSource>;
178
+
179
+ static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd<FusionOp>::value;
180
+
181
+ using StrideC = typename Conv::ConvKernel::StrideC;
182
+ using StrideD = typename Conv::ConvKernel::StrideD;
183
+ using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp;
184
+
185
+ static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp;
186
+ static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions;
187
+ using ProblemShape = cutlass::conv::ConvProblemShape<ConvOp, NumSpatialDimensions>;
188
+ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
189
+ using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
190
+ using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize;
191
+ using Splits = typename gemm::device::detail::Splits;
192
+
193
+ using Schedule = typename Conv::DispatchPolicy::Schedule;
194
+ /// Initialization
195
+ cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform;
196
+ cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform;
197
+ cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform;
198
+ cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform;
199
+ cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros
200
+ uint64_t seed = 6090;
201
+ float epsilon = 0.0f;
202
+ int split_p_slices = 1;
203
+ thrust::universal_vector<ElementA> tensor_A;
204
+ thrust::universal_vector<ElementB> tensor_B;
205
+ thrust::universal_vector<ElementC> tensor_C;
206
+ thrust::universal_vector<ElementD> tensor_D_computed;
207
+ thrust::universal_vector<ElementD> tensor_D_reference;
208
+ thrust::universal_vector<ElementBias> tensor_bias;
209
+ thrust::universal_vector<ElementScalar> tensor_alpha;
210
+ thrust::universal_vector<ElementScalar> tensor_beta;
211
+
212
+ // Return true on success, else false
213
+ bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) {
214
+ tensor_A.resize(sizeof(ElementA) * problem_shape.size_A());
215
+ tensor_B.resize(sizeof(ElementB) * problem_shape.size_B());
216
+ tensor_C.resize(sizeof(ElementC) * problem_shape.size_C());
217
+ tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C());
218
+ tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C());
219
+ tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B())));
220
+ if constexpr (IsPerChannelScaleEnabled) {
221
+ tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B())));
222
+ tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B())));
223
+ }
224
+ initialize_values(tensor_A, init_A, seed);
225
+ initialize_values(tensor_B, init_B, seed * 11);
226
+ initialize_values(tensor_C, init_C, seed * 17);
227
+ initialize_values(tensor_bias, init_bias, seed * 19);
228
+ if constexpr (IsPerChannelScaleEnabled) {
229
+ initialize_values(tensor_alpha, init_bias, seed * 23);
230
+ if constexpr (DisableSource) {
231
+ initialize_values(tensor_beta, init_disable, seed * 27);
232
+ }
233
+ else {
234
+ initialize_values(tensor_beta, init_bias, seed * 27);
235
+ }
236
+ }
237
+
238
+ bool flag = true;
239
+ if constexpr (isSparseEnabled) {
240
+ flag &= params.initialize(problem_shape, tensor_B, static_cast<int>(seed + 2023));
241
+ }
242
+
243
+ return flag;
244
+ }
245
+
246
+ // Determine SMEM requirements and waive if not satisfied
247
+ bool sufficient() const {
248
+ int device_idx;
249
+ cudaError_t result = cudaGetDevice(&device_idx);
250
+ if (result != cudaSuccess) {
251
+ throw std::runtime_error("cudaGetDevice() API call failed.");
252
+ }
253
+
254
+ int max_smem_size;
255
+ result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx);
256
+ if (result != cudaSuccess) {
257
+ throw std::runtime_error("cudaDeviceGetAttribute() failed");
258
+ }
259
+
260
+ return max_smem_size >= Conv::ConvKernel::SharedStorageSize;
261
+ }
262
+
263
+ auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) {
264
+ using TensorExtent = cute::array<int32_t, NumSpatialDimensions + 3>;
265
+ using TensorStride = cute::array<int64_t, NumSpatialDimensions + 3>;
266
+
267
+ TensorExtent shape_a_g{};
268
+ TensorExtent shape_b_g{};
269
+ TensorExtent shape_c_g{};
270
+ TensorStride stride_a_g{};
271
+ TensorStride stride_b_g{};
272
+ TensorStride stride_c_g{};
273
+
274
+ auto shape_a = cute::reverse(problem_shape.shape_A);
275
+ auto shape_b = cute::reverse(problem_shape.shape_B);
276
+ auto shape_c = cute::reverse(problem_shape.shape_C);
277
+ auto stride_a = cute::reverse(problem_shape.stride_A);
278
+ auto stride_b = cute::reverse(problem_shape.stride_B);
279
+ auto stride_c = cute::reverse(problem_shape.stride_C);
280
+
281
+ int32_t G = problem_shape.groups;
282
+
283
+ if constexpr (ConvOp == cutlass::conv::Operator::kFprop ||
284
+ ConvOp == cutlass::conv::Operator::kDgrad) {
285
+ // shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g)
286
+ // shape_b_g = (c,s,r,k,t,g)
287
+ // shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g)
288
+ shape_a_g = cute::to_array<int32_t>(tuple_cat(
289
+ cute::make_shape(cute::size<0>(shape_a) / G),
290
+ cute::take<1,NumSpatialDimensions + 2>(shape_a),
291
+ cute::make_shape(G)));
292
+ shape_b_g = cute::to_array<int32_t>(tuple_cat(
293
+ cute::take<0,NumSpatialDimensions + 1>(shape_b),
294
+ cute::make_shape(cute::size<NumSpatialDimensions + 1>(shape_b) / G, G)));
295
+ shape_c_g = cute::to_array<int32_t>(tuple_cat(
296
+ cute::make_shape(cute::size<0>(shape_c) / G),
297
+ cute::take<1,NumSpatialDimensions + 2>(shape_c),
298
+ cute::make_shape(G)));
299
+
300
+ stride_a_g = cute::to_array<int64_t>(append(stride_a, cute::size<0>(shape_a) / G));
301
+ stride_b_g = cute::to_array<int64_t>(append(stride_b,
302
+ cute::size<NumSpatialDimensions + 1>(stride_b) * cute::size<NumSpatialDimensions + 1>(shape_b) / G));
303
+ stride_c_g = cute::to_array<int64_t>(append(stride_c, cute::size<0>(shape_c) / G));
304
+ }
305
+ else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
306
+ // shape_a_g = (k,q,p,z,n,g)
307
+ // shape_b_g = (c,w,h,d,n,g)
308
+ // shape_c_g = (c,s,r,k,t,g)
309
+ shape_a_g = cute::to_array<int32_t>(tuple_cat(
310
+ cute::make_shape(cute::size<0>(shape_a) / G),
311
+ cute::take<1,NumSpatialDimensions + 2>(shape_a),
312
+ cute::make_shape(G)));
313
+ shape_b_g = cute::to_array<int32_t>(tuple_cat(
314
+ cute::make_shape(cute::size<0>(shape_b) / G),
315
+ cute::take<1,NumSpatialDimensions + 2>(shape_b),
316
+ cute::make_shape(G)));
317
+ shape_c_g = cute::to_array<int32_t>(tuple_cat(
318
+ cute::take<0,NumSpatialDimensions + 1>(shape_c),
319
+ cute::make_shape(cute::size<NumSpatialDimensions + 1>(shape_c) / G, G)));
320
+
321
+ stride_a_g = cute::to_array<int64_t>(append(stride_a, cute::size<0>(shape_a) / G));
322
+ stride_b_g = cute::to_array<int64_t>(append(stride_b, cute::size<0>(shape_b) / G));
323
+ stride_c_g = cute::to_array<int64_t>(append(stride_c,
324
+ cute::size<NumSpatialDimensions + 1>(stride_c) * cute::size<NumSpatialDimensions + 1>(shape_c) / G));
325
+ }
326
+
327
+ return make_tuple(shape_a_g, shape_b_g, shape_c_g,
328
+ stride_a_g, stride_b_g, stride_c_g);
329
+ }
330
+
331
+ // Executes one test
332
+ bool run(
333
+ ProblemShape const& problem_shape,
334
+ ElementScalar alpha = ElementScalar(1),
335
+ ElementScalar beta = ElementScalar(0),
336
+ dim3 cluster_shape = dim3(0, 0, 0),
337
+ dim3 cluster_shape_fallback = dim3(0, 0, 0),
338
+ RasterOrderOptions raster_order = RasterOrderOptions::Heuristic,
339
+ MaxSwizzleSize max_swizzle = MaxSwizzleSize{},
340
+ Splits splits = Splits{},
341
+ DecompositionMode decomposition_mode = DecompositionMode::Heuristic
342
+ ) {
343
+
344
+ // Waive test if insufficient CUDA device
345
+ if (!sufficient()) {
346
+ if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) {
347
+ std::cerr << "Test waived due to insufficient CUDA device.\n";
348
+ }
349
+ return true;
350
+ }
351
+
352
+ bool ret = initialize(problem_shape);
353
+
354
+ if (!ret) {
355
+ std::cerr << "initialize failed for the given problem_shape: \n";
356
+ return false;
357
+ }
358
+
359
+ cutlass::KernelHardwareInfo hw_info;
360
+ cudaGetDevice(&hw_info.device_id);
361
+ hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
362
+
363
+ hw_info.cluster_shape = cluster_shape;
364
+ hw_info.cluster_shape_fallback = cluster_shape_fallback;
365
+
366
+ // configure the operator
367
+ Conv conv_op;
368
+ auto stride_C = StrideC{};
369
+ auto stride_D = StrideD{};
370
+ if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
371
+ stride_C = cutlass::make_cute_packed_stride(
372
+ StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
373
+ stride_D = cutlass::make_cute_packed_stride(
374
+ StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp);
375
+ }
376
+ // Need to support non-packed output strides for fprop and dgrad kernel.
377
+ else {
378
+ cute::for_each(cute::make_seq<cute::rank<0>(StrideC{})>{}, [&](auto i) {
379
+ cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i];
380
+ });
381
+ cute::for_each(cute::make_seq<cute::rank<0>(StrideD{})>{}, [&](auto i) {
382
+ cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i];
383
+ });
384
+ }
385
+
386
+ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
387
+ using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
388
+
389
+ typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{};
390
+ if constexpr (cute::is_same_v<typename Conv::ConvKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>) {
391
+ scheduler_args = { static_cast<int>(splits), static_cast<int>(max_swizzle), raster_order, decomposition_mode };
392
+ }
393
+
394
+ auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B);
395
+
396
+ auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments {
397
+ {},
398
+ tensor_C.data().get(),
399
+ stride_C,
400
+ tensor_D_computed.data().get(),
401
+ stride_D,
402
+ };
403
+
404
+ auto args = typename Conv::Arguments {
405
+ problem_shape,
406
+ mainloop_args, // MainloopArguments
407
+ epilogue_args, // EpilogueArguments
408
+ hw_info,
409
+ scheduler_args
410
+ };
411
+
412
+ auto &fusion_args = args.epilogue.thread;
413
+
414
+ fusion_args.alpha = alpha;
415
+ fusion_args.beta = beta;
416
+
417
+ if constexpr (IsPerChannelScaleEnabled) {
418
+ fusion_args.alpha_ptr = tensor_alpha.data().get();
419
+ fusion_args.beta_ptr = tensor_beta.data().get();
420
+ }
421
+
422
+ if constexpr (IsBiasEnabled) {
423
+ fusion_args.bias_ptr = tensor_bias.data().get();
424
+ }
425
+
426
+ // Clamp bound
427
+ if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::Clamp<ElementCompute>>) {
428
+ fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits<ElementCompute>::lowest();
429
+ fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits<ElementCompute>::max();
430
+ }
431
+
432
+ // Scale
433
+ if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU_taylor<ElementCompute>> ||
434
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledGELU<ElementCompute>> ||
435
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledSiLu<ElementCompute>> ||
436
+ cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::ScaledHardSwish<ElementCompute>> ) {
437
+ fusion_args.activation.scale = ElementCompute{1};
438
+ }
439
+
440
+ // LeakyRelu
441
+ if constexpr (cute::is_same_v<ActivationFunctor, cutlass::epilogue::thread::LeakyReLU<ElementCompute>> ) {
442
+ fusion_args.activation.leaky_alpha = ElementCompute{0};
443
+ }
444
+
445
+ cutlass::Status status = cutlass::Status::kInvalid;
446
+
447
+ status = conv_op.can_implement(args);
448
+ EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess);
449
+ if (status != cutlass::Status::kSuccess) {
450
+ std::cerr << "can_implement failed for the given problem_shape: \n";
451
+ print(problem_shape);
452
+ return false;
453
+ }
454
+
455
+ // find workspace requirement for parallel split-k reduction
456
+ size_t workspace_size = Conv::get_workspace_size(args);
457
+ thrust::universal_vector<uint8_t> workspace(workspace_size);
458
+
459
+ status = conv_op.initialize(args, workspace.data().get());
460
+ if (status != cutlass::Status::kSuccess) {
461
+ cudaError_t error = cudaGetLastError();
462
+ std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n";
463
+ return true;
464
+ }
465
+
466
+ // run conv3d operator
467
+ status = conv_op();
468
+
469
+ EXPECT_TRUE(status == cutlass::Status::kSuccess);
470
+ if (status != cutlass::Status::kSuccess) {
471
+ return false;
472
+ }
473
+
474
+ bool passed = false;
475
+ cudaError_t result = cudaDeviceSynchronize();
476
+ EXPECT_EQ(result, cudaSuccess) << " Kernel execution error: "
477
+ << cudaGetErrorString(result);
478
+
479
+ // Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us
480
+ auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] =
481
+ transform_shape_and_stride_with_groups(problem_shape);
482
+ auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B())));
483
+
484
+ auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA));
485
+ auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB));
486
+ auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC));
487
+ auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC));
488
+ auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC));
489
+ auto mBias = make_tensor(tensor_bias.data().get(), make_layout(shape_mBias));
490
+ auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias));
491
+ auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias));
492
+
493
+ cutlass::reference::host::ConvEpilogueFusionParams<
494
+ ElementAccumulator,
495
+ ElementScalar,
496
+ ElementCompute,
497
+ ElementC,
498
+ ElementD,
499
+ IsResidualEnabled,
500
+ decltype(mAlpha),
501
+ decltype(mBeta),
502
+ decltype(mBias),
503
+ ActivationFunctor>
504
+ epilogue_fusion_params{};
505
+
506
+ epilogue_fusion_params.alpha = alpha;
507
+ epilogue_fusion_params.beta = beta;
508
+
509
+ if constexpr (IsPerChannelScaleEnabled) {
510
+ epilogue_fusion_params.tensor_alpha = mAlpha;
511
+ epilogue_fusion_params.tensor_beta = mBeta;
512
+ }
513
+
514
+ if constexpr (IsBiasEnabled) {
515
+ epilogue_fusion_params.tensor_bias = mBias;
516
+ }
517
+
518
+ auto padding = cute::reverse(problem_shape.lower_padding);
519
+ auto tstride = cute::reverse(problem_shape.traversal_stride);
520
+ auto dilation = cute::reverse(problem_shape.dilation);
521
+
522
+ cutlass::reference::host::ConvReferenceImpl<
523
+ ConvOp,
524
+ NumSpatialDimensions,
525
+ decltype(mA),
526
+ decltype(mB),
527
+ decltype(mC),
528
+ decltype(mD_ref),
529
+ decltype(padding),
530
+ decltype(tstride),
531
+ decltype(dilation),
532
+ decltype(epilogue_fusion_params)>
533
+ reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params);
534
+
535
+ //
536
+ // Reference check - support caching results
537
+ //
538
+
539
+ CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey<
540
+ ProblemShape,
541
+ ElementA,
542
+ ElementB,
543
+ ElementC,
544
+ ElementD
545
+ >(
546
+ ConvOp,
547
+ problem_shape,
548
+ alpha,
549
+ beta,
550
+ tensor_A,
551
+ tensor_B,
552
+ tensor_C
553
+ );
554
+
555
+ //
556
+ // Look for the cached key
557
+ //
558
+
559
+ bool cached_result_loaded = false;
560
+ CachedTestResult cached_test_result;
561
+
562
+ std::string convnd_result_cache_name =
563
+ std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt";
564
+
565
+ #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
566
+ CachedTestResultListing cached_results(convnd_result_cache_name);
567
+
568
+ auto cached = cached_results.find(cached_test_key);
569
+
570
+ cached_result_loaded = cached.first;
571
+ if (cached_result_loaded) {
572
+ cached_test_result = cached.second;
573
+ }
574
+ #endif
575
+
576
+ if (!cached_result_loaded) {
577
+ // Compute reference
578
+ reference_impl.compute_reference();
579
+
580
+ #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
581
+ cached_test_result.D = TensorHash(tensor_D_reference);
582
+ CachedTestResultListing cached_results(convnd_result_cache_name);
583
+
584
+ cached_results.append(cached_test_key, cached_test_result);
585
+ cached_results.write(convnd_result_cache_name);
586
+ #endif
587
+ } // if (!cached_result_loaded)
588
+
589
+ #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
590
+ uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed);
591
+ passed = (tensor_D_computed_hash == cached_test_result.D);
592
+ // If hash fails, double check against reference implementation.
593
+ if(!passed) {
594
+ std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key
595
+ << ", comparing with reference implementation now.\n";
596
+ if (cached_result_loaded) {
597
+ // Compute reference
598
+ reference_impl.compute_reference();
599
+ }
600
+ // Validate kernel against reference
601
+ passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon);
602
+ }
603
+ #else
604
+ // Validate kernel against reference
605
+ passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon);
606
+ #endif
607
+
608
+ EXPECT_TRUE(passed);
609
+ return passed;
610
+ }
611
+
612
+ template<
613
+ class Engine, class Layout,
614
+ class EngineA, class LayoutA,
615
+ class EngineB, class LayoutB,
616
+ class EngineAlpha, class LayoutAlpha,
617
+ class EngineBeta, class LayoutBeta,
618
+ class EngineBias, class LayoutBias>
619
+ static constexpr bool
620
+ compare_reference(
621
+ cute::Tensor<Engine, Layout> const& reference,
622
+ cute::Tensor<Engine, Layout> const& computed,
623
+ cute::Tensor<EngineA, LayoutA> const& A,
624
+ cute::Tensor<EngineB, LayoutB> const& B,
625
+ cute::Tensor<EngineAlpha, LayoutAlpha> const& tensor_alpha,
626
+ cute::Tensor<EngineBeta, LayoutBeta> const& tensor_beta,
627
+ cute::Tensor<EngineBias, LayoutBias> const& tensor_bias,
628
+ float epsilon = 0.0f) {
629
+ if (size(reference) != size(computed)) {
630
+ return false;
631
+ }
632
+
633
+ bool passed = true;
634
+ if (epsilon == 0.0f) {
635
+ // fast refcheck w/o epsilon
636
+ for (size_t i = 0; i < size_t(size(reference)); ++i) {
637
+ if (reference(i) != computed(i)) {
638
+ passed = false;
639
+ printf("[%llu] %f, %f\n", static_cast<unsigned long long>(i),
640
+ float(reference(i)), float(computed(i)));
641
+ break;
642
+ }
643
+ }
644
+ } else {
645
+ // refcheck with epsilon
646
+ for (size_t i = 0; i < size_t(size(reference)); ++i) {
647
+ auto ref = static_cast<float>(reference(i));
648
+ auto act = static_cast<float>(computed(i));
649
+ auto abs_error = std::abs(act - ref);
650
+ auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f);
651
+ if (std::isnan(abs_error) || std::isnan(rel_error) ||
652
+ std::min(abs_error, rel_error) > epsilon) {
653
+ passed = false;
654
+ printf("[%llu] %f, %f\n", static_cast<unsigned long long>(i),
655
+ float(reference(i)), float(computed(i)));
656
+ break;
657
+ }
658
+ }
659
+ }
660
+ #if CUTLASS_DEBUG_TRACE_LEVEL > 1
661
+ if (not passed) {
662
+ cute::print("Reference:");
663
+ cute::print_tensor(reference);
664
+ cute::print("\nComputed:");
665
+ cute::print_tensor(computed);
666
+ cute::print("\n");
667
+
668
+ for (size_t i = 0; i < size_t(size(A)); ++i) {
669
+ printf("[%llu]: A = %f\n", static_cast<unsigned long long>(i), float(A(i)));
670
+ }
671
+ for (size_t i = 0; i < size_t(size(B)); ++i) {
672
+ printf("[%llu]: B = %f\n", static_cast<unsigned long long>(i), float(B(i)));
673
+ }
674
+ if constexpr (IsPerChannelScaleEnabled) {
675
+ for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) {
676
+ printf("[%llu]: alpha = %f\n", static_cast<unsigned long long>(i),
677
+ float(tensor_alpha(i)));
678
+ }
679
+ for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) {
680
+ printf("[%llu]: beta = %f\n", static_cast<unsigned long long>(i),
681
+ float(tensor_beta(i)));
682
+ }
683
+ }
684
+ if constexpr (IsBiasEnabled) {
685
+ for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) {
686
+ printf("[%llu]: bias = %f\n", static_cast<unsigned long long>(i),
687
+ float(tensor_bias(i)));
688
+ }
689
+ }
690
+ for (size_t i = 0; i < size_t(size(reference)); ++i) {
691
+ printf("[%llu]: ref = %f, computed = %f\n", static_cast<unsigned long long>(i),
692
+ float(reference(i)), float(computed(i)));
693
+ }
694
+ }
695
+ #endif
696
+ return passed;
697
+ }
698
+ };
699
+
700
+ /////////////////////////////////////////////////////////////////////////////////////////////////
701
+
702
+ template <typename Conv, bool SupportStrides = (Conv::DispatchPolicy::ConvOp != cutlass::conv::Operator::kDgrad)>
703
+ bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f,
704
+ dim3 cluster_shape = dim3(0, 0, 0),
705
+ dim3 cluster_shape_fallback = dim3(0, 0, 0)
706
+ ) {
707
+ using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar;
708
+
709
+ bool passed = true;
710
+ ConvTestbed<Conv> testbed;
711
+ testbed.epsilon = epsilon;
712
+ auto problem_vector = get_conv_problem_vector<
713
+ Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>();
714
+
715
+ using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode;
716
+ using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions;
717
+ using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize;
718
+ using Splits = typename gemm::device::detail::Splits;
719
+
720
+ std::vector<DecompositionMode> decomposition_modes = {DecompositionMode::Heuristic};
721
+ static constexpr bool UsesStreamKScheduler = cute::is_same_v<typename Conv::ConvKernel::TileSchedulerTag, cutlass::gemm::StreamKScheduler>;
722
+ if constexpr (UsesStreamKScheduler) {
723
+ decomposition_modes.push_back(DecompositionMode::DataParallel);
724
+ decomposition_modes.push_back(DecompositionMode::SplitK);
725
+ decomposition_modes.push_back(DecompositionMode::StreamK);
726
+ }
727
+
728
+ for (auto conv_problem : problem_vector) {
729
+ #if CUTLASS_DEBUG_TRACE_LEVEL > 0
730
+ print(conv_problem);
731
+ #endif
732
+ for (DecompositionMode decomp_mode : decomposition_modes) {
733
+ std::vector problem_splits = {Splits{1}};
734
+ if constexpr (UsesStreamKScheduler) {
735
+ if (decomp_mode == DecompositionMode::SplitK) {
736
+ problem_splits.push_back(Splits{2});
737
+ problem_splits.push_back(Splits{4});
738
+ }
739
+ }
740
+ for (auto splits : problem_splits) {
741
+
742
+ passed = testbed.run(
743
+ conv_problem,
744
+ cutlass::from_real<ElementScalar>(alpha),
745
+ cutlass::from_real<ElementScalar>(beta),
746
+ cluster_shape,
747
+ cluster_shape_fallback,
748
+ RasterOrderOptions::Heuristic, // raster_order
749
+ MaxSwizzleSize(1),
750
+ splits,
751
+ decomp_mode
752
+ );
753
+ if (!passed) {
754
+ printf("Failed test for "); print(conv_problem);
755
+ return false;
756
+ }
757
+ } // splits
758
+ } // decomposition_mode
759
+ }
760
+
761
+ return passed;
762
+ }
763
+
764
+ /////////////////////////////////////////////////////////////////////////////////////////////////
765
+
766
+ } // namespace test::conv::device
767
+
768
+ /////////////////////////////////////////////////////////////////////////////////////////////////
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #include "cutlass_unit_test.h"
33
+
34
+ #include <iostream>
35
+ #include <iomanip>
36
+ #include <utility>
37
+ #include <type_traits>
38
+ #include <vector>
39
+ #include <numeric>
40
+
41
+ #include <thrust/host_vector.h>
42
+ #include <thrust/device_vector.h>
43
+
44
+ #include <cute/tensor.hpp>
45
+
46
+ using namespace cute;
47
+
48
+ template <class ElementType, class SmemLayout>
49
+ struct SharedStorage
50
+ {
51
+ cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
52
+ };
53
+
54
+ template <class T, class TiledCopy, class GmemLayout, class SmemLayout>
55
+ __global__ void
56
+ test_tiled_cp_async_device_cute(T const* g_in, T* g_out,
57
+ TiledCopy const tiled_copy,
58
+ GmemLayout gmem_layout, SmemLayout smem_layout)
59
+ {
60
+ using namespace cute;
61
+
62
+ extern __shared__ char shared_memory[];
63
+ using SharedStorage = SharedStorage<T, SmemLayout>;
64
+ SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
65
+
66
+ auto thr_copy = tiled_copy.get_slice(threadIdx.x);
67
+ Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout);
68
+ Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout);
69
+
70
+ // Construct SMEM tensor
71
+ Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout);
72
+
73
+ auto tAgA = thr_copy.partition_S(gA);
74
+ auto tAsA = thr_copy.partition_D(sA);
75
+
76
+ #if 0
77
+ if (thread0()) {
78
+ print("gA : "); print(gA.layout()); print("\n");
79
+ print("sA : "); print(sA.layout()); print("\n");
80
+ print("tAgA: "); print(tAgA.layout()); print("\n");
81
+ print("tAsA: "); print(tAsA.layout()); print("\n");
82
+ }
83
+ #endif
84
+
85
+ copy(tiled_copy, tAgA, tAsA);
86
+
87
+ cp_async_fence();
88
+ cp_async_wait<0>();
89
+ __syncthreads();
90
+
91
+ // Store trivially smem -> gmem
92
+
93
+ if (thread0()) {
94
+ copy(sA, gB);
95
+ }
96
+
97
+ }
98
+
99
+ template <class T, class TiledCopy, class GMEM_Layout, class SMEM_Layout>
100
+ void
101
+ test_tiled_cp_async(
102
+ TiledCopy const tiled_copy,
103
+ GMEM_Layout const& gmem_layout,
104
+ SMEM_Layout const& smem_layout)
105
+ {
106
+ using namespace cute;
107
+
108
+ // Allocate and initialize host test data
109
+ size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
110
+ thrust::host_vector<T> h_in(N);
111
+ Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);
112
+ for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast<T>(i % 13); }
113
+
114
+ // Allocate and initialize device test data
115
+ thrust::device_vector<T> d_in = h_in;
116
+ thrust::device_vector<T> d_out(h_in.size(), T(-1));
117
+
118
+ // Launch
119
+ int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
120
+ test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>(
121
+ reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
122
+ reinterpret_cast<T*> (raw_pointer_cast(d_out.data())),
123
+ tiled_copy,
124
+ gmem_layout,
125
+ smem_layout);
126
+
127
+ // Copy results back to host
128
+ thrust::host_vector<T> h_out = d_out;
129
+ Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);
130
+
131
+ // Validate the results. Print only the first 3 errors.
132
+ int count = 3;
133
+ for (int i = 0; i < size(hA_out) && count > 0; ++i) {
134
+ EXPECT_EQ(hA_in(i), hA_out(i));
135
+ if (hA_in(i) != hA_out(i)) {
136
+ --count;
137
+ }
138
+ }
139
+ }
140
+
141
+ template <typename T, typename M, typename N, typename GMEM_STRIDE_TYPE, typename SMEM_LAYOUT, typename TILED_COPY>
142
+ void test_cp_async_no_swizzle() {
143
+ using namespace cute;
144
+ auto smem_atom = SMEM_LAYOUT{};
145
+ auto smem_layout = tile_to_shape(smem_atom, Shape<M, N>{});
146
+ auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{});
147
+ test_tiled_cp_async<T>(TILED_COPY{}, gmem_layout, smem_layout);
148
+ }
149
+
150
+ template <typename T, typename M, typename N, typename GMEM_STRIDE_TYPE, typename SWIZZLE_ATOM, typename SMEM_LAYOUT, typename TILED_COPY>
151
+ void test_cp_async_with_swizzle() {
152
+ using namespace cute;
153
+ auto swizzle_atom = SWIZZLE_ATOM{};
154
+ auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{});
155
+ auto smem_layout = tile_to_shape(smem_atom, Shape<M, N>{});
156
+ auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{});
157
+ test_tiled_cp_async<T>(TILED_COPY{}, gmem_layout, smem_layout);
158
+ }
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp ADDED
@@ -0,0 +1,775 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include "cutlass/relatively_equal.h"
35
+ #include "cutlass_unit_test.h"
36
+ #include "cutlass/util/reference/host/tensor_compare.h"
37
+
38
+ #include <iostream>
39
+
40
+ #include <thrust/host_vector.h>
41
+ #include <thrust/device_vector.h>
42
+
43
+ #include <cute/tensor.hpp>
44
+
45
+ using namespace cute;
46
+
47
+ template<typename T>
48
+ struct fp64_tester {
49
+ using value_type = double;
50
+ };
51
+
52
+ template<typename T>
53
+ struct fp64_tester<complex<T>> {
54
+ using value_type = complex<double>;
55
+ };
56
+
57
+ template<class TA,
58
+ class TB,
59
+ class TC,
60
+ class ALayout, // logical shape (M, K)
61
+ class BLayout, // logical shape (N, K)
62
+ class CLayout> // logical shape (M, N)
63
+ auto host_generate_gemm_inputs(
64
+ ALayout a_layout,
65
+ BLayout b_layout,
66
+ CLayout c_layout
67
+ ) {
68
+ thrust::host_vector<TA> h_a(cosize(a_layout));
69
+ thrust::host_vector<TB> h_b(cosize(b_layout));
70
+ thrust::host_vector<TC> h_c(cosize(c_layout));
71
+ thrust::host_vector<TC> h_c_out(cosize(c_layout));
72
+
73
+ auto h_a_tensor = make_tensor(h_a.data(), a_layout);
74
+ auto h_b_tensor = make_tensor(h_b.data(), b_layout);
75
+ auto h_c_tensor = make_tensor(h_c.data(), c_layout);
76
+ size_t max_size = std::max<size_t>({static_cast<size_t>(size(a_layout)),
77
+ static_cast<size_t>(size(b_layout)),
78
+ static_cast<size_t>(size(c_layout))});
79
+ for (size_t i = 0; i < max_size; ++i) {
80
+ double di = static_cast<double>(i);
81
+ if(i < size(a_layout)) {
82
+ h_a_tensor(i) = static_cast<TA>(di / size(a_layout));
83
+ }
84
+ if(i < size(b_layout)) {
85
+ h_b_tensor(i) = static_cast<TB>(di / size(a_layout));
86
+ }
87
+ if(i < size(c_layout)) {
88
+ h_c_tensor(i) = static_cast<TC>((di*di) / size(a_layout));
89
+ }
90
+ }
91
+
92
+ return std::make_tuple(h_a, h_b, h_c, h_c_out);
93
+ }
94
+
95
+ template<class Alpha, class EngineA, class ALayout,
96
+ class EngineB, class BLayout,
97
+ class Beta, class EngineC, class CLayout,
98
+ class ALoadTransform = cute::identity,
99
+ class BLoadTransform = cute::identity,
100
+ class CLoadTransform = cute::identity,
101
+ class CStoreTransform = cute::identity>
102
+ thrust::host_vector<typename EngineC::value_type>
103
+ host_reference_gemm(Alpha alpha,
104
+ Tensor<EngineA, ALayout> const& h_a_tensor,
105
+ Tensor<EngineB, BLayout> const& h_b_tensor,
106
+ Beta beta,
107
+ Tensor<EngineC, CLayout> const& h_c_tensor,
108
+ ALoadTransform const& a_load_transform = {},
109
+ BLoadTransform const& b_load_transform = {},
110
+ CLoadTransform const& c_load_transform = {},
111
+ CStoreTransform const& c_store_transform = {})
112
+ {
113
+ // Cannot use ::value_type because it propagates to complex::value_type,
114
+ // so ViewEngine<complex<double>>::value_type == double
115
+ using TA = remove_cv_t<typename EngineA::element_type>;
116
+ using TB = remove_cv_t<typename EngineB::element_type>;
117
+ using TC = remove_cv_t<typename EngineC::element_type>;
118
+
119
+ using tester = fp64_tester<TC>;
120
+ using ABC_64 = typename tester::value_type;
121
+
122
+ static_assert(std::is_same_v<typename fp64_tester<TA>::value_type, typename fp64_tester<TB>::value_type>);
123
+ static_assert(std::is_same_v<typename fp64_tester<TB>::value_type, typename fp64_tester<TC>::value_type>);
124
+
125
+ thrust::host_vector<TC> h_c_ref(cosize(h_c_tensor.layout()), static_cast<TC>(0.0));
126
+ auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout());
127
+ // A * B
128
+ for (int k = 0; k < size<1>(h_a_tensor); k++) {
129
+ for (int m = 0; m < size<0>(h_a_tensor); m++) {
130
+ for (int n = 0; n < size<0>(h_b_tensor); n++) {
131
+ const auto a_value = a_load_transform(h_a_tensor(m, k));
132
+ const auto b_value = b_load_transform(h_b_tensor(n, k));
133
+ const auto a_value_fp64 = static_cast<ABC_64>(a_value);
134
+ const auto b_value_fp64 = static_cast<ABC_64>(b_value);
135
+ h_c_ref_tensor(m, n) += static_cast<TC>(a_value_fp64 * b_value_fp64);
136
+ }
137
+ }
138
+ }
139
+ // C = A*B + C
140
+ for (int i = 0; i < size(h_c_ref_tensor); i++) {
141
+ const auto ab_value_fp64 = static_cast<ABC_64>(h_c_ref_tensor(i));
142
+ const auto c_value_fp64 = static_cast<ABC_64>(c_load_transform(h_c_tensor(i)));
143
+ h_c_ref_tensor(i) = c_store_transform(static_cast<TC>(alpha * ab_value_fp64 + beta * c_value_fp64));
144
+ }
145
+
146
+ return h_c_ref;
147
+ }
148
+
149
+ template<class EngineC, class CLayout>
150
+ void verify_gemm_correctness(cute::Tensor<EngineC, CLayout> const& h_c_out_tensor,
151
+ cute::Tensor<EngineC, CLayout> const& h_c_ref_tensor)
152
+ {
153
+ // Cannot use ::value_type because it propagates to complex::value_type,
154
+ // so ViewEngine<complex<double>>::value_type == double
155
+ using TC = remove_cv_t<typename EngineC::element_type>;
156
+
157
+ using tester = fp64_tester<TC>;
158
+ using ABC_64 = typename tester::value_type;
159
+
160
+ for (int i = 0; i < size(h_c_ref_tensor); i++) {
161
+ ABC_64 h_c_ref_i = h_c_ref_tensor(i);
162
+ ABC_64 h_c_out_i = h_c_out_tensor(i);
163
+ double epsilon(0.1f);
164
+ double nonzero_floor(std::numeric_limits<double>::min());
165
+ bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor);
166
+ ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i;
167
+ }
168
+ }
169
+
170
+
171
+ template<uint32_t ThreadBlockSize,
172
+ uint32_t CopyMaxVecBits,
173
+ class GMemALayout,
174
+ class GMemBLayout,
175
+ class GMemCLayout,
176
+ class SMemALayout,
177
+ class SMemBLayout,
178
+ class SMemCLayout,
179
+ class TA,
180
+ class TB,
181
+ class TC,
182
+ class Alpha,
183
+ class Beta,
184
+ class TiledMma,
185
+ class ALoadTransform,
186
+ class BLoadTransform,
187
+ class CLoadTransform,
188
+ class CStoreTransform,
189
+ class SMemCopyOpA,
190
+ class SMemCopyOpB,
191
+ class SMemCopyLdOpC,
192
+ class SMemCopyStOpC>
193
+ __launch_bounds__(ThreadBlockSize) __global__ void
194
+ cooperative_gemm_kernel(GMemALayout gmem_a_layout,
195
+ GMemBLayout gmem_b_layout,
196
+ GMemCLayout gmem_c_layout,
197
+ SMemALayout smem_a_layout,
198
+ SMemBLayout smem_b_layout,
199
+ SMemCLayout smem_c_layout,
200
+ TA const* a,
201
+ TB const* b,
202
+ TC const* c,
203
+ TC * c_out,
204
+ Alpha const alpha,
205
+ Beta const beta,
206
+ TiledMma tiled_mma,
207
+ ALoadTransform a_load_transform,
208
+ BLoadTransform b_load_transform,
209
+ CLoadTransform c_load_transform,
210
+ CStoreTransform c_store_transform,
211
+ SMemCopyOpA a_copy_op,
212
+ SMemCopyOpB b_copy_op,
213
+ SMemCopyLdOpC c_copy_ld_op,
214
+ SMemCopyStOpC c_copy_st_op)
215
+ {
216
+ using namespace cute;
217
+
218
+ Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout);
219
+ Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout);
220
+ Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout);
221
+ Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout);
222
+
223
+ constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
224
+
225
+ extern __shared__ float4 smem_buf[];
226
+ auto* smem_ptr = reinterpret_cast<unsigned char*>(smem_buf);
227
+ auto* smem_ptr_a = smem_ptr;
228
+ auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes);
229
+ auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes);
230
+
231
+ Tensor s_a_tensor = make_tensor(make_smem_ptr<TA>(smem_ptr_a), smem_a_layout);
232
+ Tensor s_b_tensor = make_tensor(make_smem_ptr<TB>(smem_ptr_b), smem_b_layout);
233
+ Tensor s_c_tensor = make_tensor(make_smem_ptr<TC>(smem_ptr_c), smem_c_layout);
234
+
235
+ cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_a_tensor, s_a_tensor);
236
+ cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_b_tensor, s_b_tensor);
237
+ cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_c_tensor, s_c_tensor);
238
+
239
+ cp_async_fence();
240
+ cp_async_wait<0>();
241
+ __syncthreads();
242
+
243
+ cooperative_gemm(
244
+ threadIdx.x, tiled_mma,
245
+ alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor,
246
+ a_load_transform, b_load_transform, c_load_transform, c_store_transform,
247
+ a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op
248
+ );
249
+ __syncthreads();
250
+
251
+ cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, s_c_tensor, g_c_out_tensor);
252
+ }
253
+
254
// Device kernel: single-thread-block GEMM with A/B staged in shared memory
// and the C accumulator kept in registers (per-thread MMA fragments).
//
// Pipeline: cooperatively copy A and B gmem->smem, load C gmem->rmem with
// predication, run the smem-A/B x rmem-C cooperative_gemm overload, then
// store the fragments back to c_out with the same predication.
template<uint32_t ThreadBlockSize,
         uint32_t CopyMaxVecBits,
         class GMemALayout,
         class GMemBLayout,
         class GMemCLayout,
         class SMemALayout,
         class SMemBLayout,
         class TA,
         class TB,
         class TC,
         class TiledMma,
         class ALoadTransform,
         class BLoadTransform,
         class CLoadTransform,
         class CStoreTransform,
         class SMemCopyOpA,
         class SMemCopyOpB>
__launch_bounds__(ThreadBlockSize) __global__ void
cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout,
                               GMemBLayout gmem_b_layout,
                               GMemCLayout gmem_c_layout,
                               SMemALayout smem_a_layout,
                               SMemBLayout smem_b_layout,
                               TA const* a,
                               TB const* b,
                               TC const* c,
                               TC      * c_out,
                               TiledMma tiled_mma,
                               ALoadTransform  a_load_transform,
                               BLoadTransform  b_load_transform,
                               CLoadTransform  c_load_transform,
                               CStoreTransform c_store_transform,
                               SMemCopyOpA a_copy_op,
                               SMemCopyOpB b_copy_op)
{
  using namespace cute;

  // Global-memory views over the raw pointers. c and c_out share a layout.
  Tensor g_a_tensor     = make_tensor(make_gmem_ptr(a), gmem_a_layout);
  Tensor g_b_tensor     = make_tensor(make_gmem_ptr(b), gmem_b_layout);
  Tensor g_c_tensor     = make_tensor(make_gmem_ptr(c), gmem_c_layout);
  Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout);

  constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;

  // Carve the dynamic smem buffer into an A region followed by a B region.
  // A's size is rounded up so the B region stays aligned to the maximum
  // copy vector width.
  extern __shared__ float4 smem_buf[];
  auto* smem_ptr   = reinterpret_cast<unsigned char*>(smem_buf);
  auto* smem_ptr_a = smem_ptr;
  auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes);

  Tensor s_a_tensor = make_tensor(make_smem_ptr<TA>(smem_ptr_a), smem_a_layout);
  Tensor s_b_tensor = make_tensor(make_smem_ptr<TB>(smem_ptr_b), smem_b_layout);

  // Whole-block copies of the A and B tiles into shared memory.
  cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_a_tensor, s_a_tensor);
  cooperative_copy<ThreadBlockSize, CopyMaxVecBits>(threadIdx.x, g_b_tensor, s_b_tensor);

  // Drain any asynchronous (cp.async-style) copies the cooperative copies
  // may have issued, then make the smem tiles visible to all threads.
  cp_async_fence();
  cp_async_wait<0>();
  __syncthreads();

  // Create C fragment for storing intermediate results
  auto thr_mma = TiledMma().get_thread_slice(threadIdx.x);
  Tensor g_c_partition     = thr_mma.partition_C(g_c_tensor);
  Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor);
  Tensor r_c_partition     = thr_mma.make_fragment_C(g_c_partition);

  // Create indexing help for predicated GEMMs
  Tensor cC   = make_identity_tensor(shape(gmem_c_layout));
  Tensor tCcC = thr_mma.partition_C(cC);

  // Load C from global
  // (always loading in predicated way)
  // NOTE(review): out-of-bounds fragment elements are left uninitialized
  // here; they are never written back because the store below uses the
  // same predicate.
  CUTE_UNROLL
  for (int i = 0; i < size(r_c_partition); ++i)
  {
    if (elem_less(tCcC(i), shape(g_c_tensor)))
    {
      r_c_partition(i) = c_load_transform(g_c_partition(i));
    }
  }

  // Accumulate A*B into the register-resident C fragments.
  cooperative_gemm(
    threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition,
    a_load_transform, b_load_transform, a_copy_op, b_copy_op
  );

  __syncthreads();

  // Store C to global
  // (always storing in predicated way)
  CUTE_UNROLL
  for (int i = 0; i < size(r_c_partition); ++i)
  {
    if (elem_less(tCcC(i), shape(g_c_tensor)))
    {
      g_c_out_partition(i) = c_store_transform(r_c_partition(i));
    }
  }
}
352
+
353
+ template<uint32_t ThreadBlockSize,
354
+ uint32_t CopyMaxVecBits,
355
+ class TA,
356
+ class TB,
357
+ class TC,
358
+ class GMemALayout, // logical shape (M, K)
359
+ class GMemBLayout, // logical shape (N, K)
360
+ class GMemCLayout, // logical shape (M, N)
361
+ class SMemALayout, // logical shape (M, K)
362
+ class SMemBLayout, // logical shape (N, K)
363
+ class SMemCLayout, // logical shape (M, N)
364
+ class TiledMma,
365
+ class ALoadTransform = cute::identity,
366
+ class BLoadTransform = cute::identity,
367
+ class CLoadTransform = cute::identity,
368
+ class CStoreTransform = cute::identity,
369
+ class ASMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
370
+ class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
371
+ class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
372
+ class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>>
373
+ void test_cooperative_gemm(GMemALayout gmem_a_layout,
374
+ GMemBLayout gmem_b_layout,
375
+ GMemCLayout gmem_c_layout,
376
+ SMemALayout smem_a_layout,
377
+ SMemBLayout smem_b_layout,
378
+ SMemCLayout smem_c_layout,
379
+ TiledMma tiled_mma,
380
+ ALoadTransform a_load_transform = {},
381
+ BLoadTransform b_load_transform = {},
382
+ CLoadTransform c_load_transform = {},
383
+ CStoreTransform c_store_transform = {},
384
+ ASMemCopyOp a_smem_copy_op = {},
385
+ BSMemCopyOp b_smem_copy_op = {},
386
+ CSMemCopyLdOp c_smem_copy_ld_op = {},
387
+ CSMemCopyStOp c_smem_copy_st_op = {})
388
+ {
389
+ static_assert(std::is_same_v<typename fp64_tester<TA>::value_type, typename fp64_tester<TB>::value_type>);
390
+ static_assert(std::is_same_v<typename fp64_tester<TB>::value_type, typename fp64_tester<TC>::value_type>);
391
+
392
+ static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM
393
+ static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN
394
+ static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK
395
+
396
+ static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM
397
+ static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN
398
+ static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK
399
+
400
+ static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout));
401
+ static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout));
402
+ static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout));
403
+
404
+ #if 0
405
+ print(" "); print("gmem: "); print(gmem_layout); print("\n");
406
+ print(" "); print("smem: "); print(smem_layout); print("\n");
407
+ print(" "); print("threads: "); print(ThreadBlockSize); print("\n");
408
+ #endif
409
+
410
+ const auto alpha = static_cast<TC>(1.1);
411
+ const auto beta = static_cast<TC>(1.2);
412
+
413
+ // Generate inputs
414
+ auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs<TA, TB, TC>(gmem_a_layout, gmem_b_layout, gmem_c_layout);
415
+
416
+ thrust::device_vector<TA> d_a(h_a);
417
+ thrust::device_vector<TB> d_b(h_b);
418
+ thrust::device_vector<TC> d_c(h_c);
419
+ thrust::device_vector<TC> d_c_out(h_c_out.size(), TC(float(-1)));
420
+
421
+ constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
422
+
423
+ const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) +
424
+ round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) +
425
+ sizeof(TC) * h_c.size();
426
+
427
+
428
+ auto kernel = cooperative_gemm_kernel<
429
+ ThreadBlockSize, CopyMaxVecBits,
430
+ GMemALayout, GMemBLayout, GMemCLayout,
431
+ SMemALayout, SMemBLayout, SMemCLayout,
432
+ TA, TB, TC, decltype(alpha), decltype(beta),
433
+ TiledMma,
434
+ ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform,
435
+ ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp
436
+ >;
437
+
438
+ ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(shared_memory_size)), 0);
439
+
440
+ kernel<<<1, ThreadBlockSize, shared_memory_size>>>(
441
+ gmem_a_layout,
442
+ gmem_b_layout,
443
+ gmem_c_layout,
444
+ smem_a_layout,
445
+ smem_b_layout,
446
+ smem_c_layout,
447
+ thrust::raw_pointer_cast(d_a.data()),
448
+ thrust::raw_pointer_cast(d_b.data()),
449
+ thrust::raw_pointer_cast(d_c.data()),
450
+ thrust::raw_pointer_cast(d_c_out.data()),
451
+ alpha,
452
+ beta,
453
+ tiled_mma,
454
+ a_load_transform,
455
+ b_load_transform,
456
+ c_load_transform,
457
+ c_store_transform,
458
+ a_smem_copy_op,
459
+ b_smem_copy_op,
460
+ c_smem_copy_ld_op,
461
+ c_smem_copy_st_op
462
+ );
463
+
464
+ cudaError_t result = cudaDeviceSynchronize();
465
+ if (result != cudaSuccess) {
466
+ cudaError_t error = cudaGetLastError();
467
+ FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n";
468
+ }
469
+
470
+ // Reference gemm
471
+ auto h_c_ref = host_reference_gemm(alpha,
472
+ make_tensor(h_a.data(), gmem_a_layout),
473
+ make_tensor(h_b.data(), gmem_b_layout),
474
+ beta,
475
+ make_tensor(h_c.data(), gmem_c_layout),
476
+ a_load_transform,
477
+ b_load_transform,
478
+ c_load_transform,
479
+ c_store_transform);
480
+
481
+ // Copy result data
482
+ h_c_out = d_c_out;
483
+
484
+ // Verify correctness
485
+ verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout),
486
+ make_tensor(h_c_ref.data(), gmem_c_layout));
487
+ }
488
+
489
+ template<uint32_t ThreadBlockSize,
490
+ uint32_t CopyMaxVecBits,
491
+ class TA,
492
+ class TB,
493
+ class TC,
494
+ class GMemALayout, // logical shape (M, K)
495
+ class GMemBLayout, // logical shape (N, K)
496
+ class GMemCLayout, // logical shape (M, N)
497
+ class SMemALayout, // logical shape (M, K)
498
+ class SMemBLayout, // logical shape (N, K)
499
+ class TiledMma,
500
+ class ALoadTransform = cute::identity,
501
+ class BLoadTransform = cute::identity,
502
+ class CLoadTransform = cute::identity,
503
+ class CStoreTransform = cute::identity,
504
+ class ASMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>,
505
+ class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment<CopyMaxVecBits>>
506
+ void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout,
507
+ GMemBLayout gmem_b_layout,
508
+ GMemCLayout gmem_c_layout,
509
+ SMemALayout smem_a_layout,
510
+ SMemBLayout smem_b_layout,
511
+ TiledMma tiled_mma,
512
+ ALoadTransform a_load_transform = {},
513
+ BLoadTransform b_load_transform = {},
514
+ CLoadTransform c_load_transform = {},
515
+ CStoreTransform c_store_transform = {},
516
+ ASMemCopyOp a_smem_copy_op = {},
517
+ BSMemCopyOp b_smem_copy_op = {})
518
+ {
519
+ static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM
520
+ static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN
521
+ static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK
522
+
523
+ static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK
524
+
525
+ static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout));
526
+ static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout));
527
+
528
+ #if 0
529
+ print(" "); print("gmem: "); print(gmem_layout); print("\n");
530
+ print(" "); print("smem: "); print(smem_layout); print("\n");
531
+ print(" "); print("threads: "); print(ThreadBlockSize); print("\n");
532
+ #endif
533
+
534
+ const auto alpha = static_cast<TC>(1.0);
535
+ const auto beta = static_cast<TC>(1.0);
536
+
537
+ // Generate inputs
538
+ auto [h_a, h_b, h_c, h_c_out] =
539
+ host_generate_gemm_inputs<TA, TB, TC>(gmem_a_layout, gmem_b_layout, gmem_c_layout);
540
+
541
+ thrust::device_vector<TA> d_a(h_a);
542
+ thrust::device_vector<TB> d_b(h_b);
543
+ thrust::device_vector<TC> d_c(h_c);
544
+ thrust::device_vector<TC> d_c_out(h_c_out.size(), static_cast<TC>(-1));
545
+
546
+ constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8;
547
+
548
+ const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) +
549
+ round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes);
550
+
551
+
552
+ auto kernel = cooperative_gemm_kernel_rmem_c<
553
+ ThreadBlockSize, CopyMaxVecBits,
554
+ GMemALayout, GMemBLayout, GMemCLayout,
555
+ SMemALayout, SMemBLayout,
556
+ TA, TB, TC,
557
+ TiledMma,
558
+ ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform,
559
+ ASMemCopyOp, BSMemCopyOp
560
+ >;
561
+
562
+ ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast<int>(shared_memory_size)), 0);
563
+
564
+ kernel<<<1, ThreadBlockSize, shared_memory_size>>>(
565
+ gmem_a_layout,
566
+ gmem_b_layout,
567
+ gmem_c_layout,
568
+ smem_a_layout,
569
+ smem_b_layout,
570
+ thrust::raw_pointer_cast(d_a.data()),
571
+ thrust::raw_pointer_cast(d_b.data()),
572
+ thrust::raw_pointer_cast(d_c.data()),
573
+ thrust::raw_pointer_cast(d_c_out.data()),
574
+ tiled_mma,
575
+ a_load_transform, b_load_transform, c_load_transform, c_store_transform,
576
+ a_smem_copy_op, b_smem_copy_op
577
+ );
578
+
579
+ cudaError_t result = cudaDeviceSynchronize();
580
+ if (result != cudaSuccess) {
581
+ cudaError_t error = cudaGetLastError();
582
+ FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n";
583
+ }
584
+
585
+ // Copy result data
586
+ h_c_out = d_c_out;
587
+
588
+ // Reference gemm
589
+ auto h_c_ref = host_reference_gemm(alpha,
590
+ make_tensor(h_a.data(), gmem_a_layout),
591
+ make_tensor(h_b.data(), gmem_b_layout),
592
+ beta,
593
+ make_tensor(h_c.data(), gmem_c_layout),
594
+ a_load_transform,
595
+ b_load_transform,
596
+ c_load_transform,
597
+ c_store_transform);
598
+
599
+ // Verify correctness
600
+ verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout),
601
+ make_tensor(h_c_ref.data(), gmem_c_layout));
602
+ }
603
+
604
+ template<uint32_t ThreadBlockSize,
605
+ uint32_t CopyMaxVecBits,
606
+ class TA,
607
+ class TB,
608
+ class TC,
609
+ class ShapeMNK,
610
+ class TiledMma,
611
+ class ... Ops>
612
+ void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk,
613
+ TiledMma tiled_mma,
614
+ Ops ... ops)
615
+ {
616
+ auto a_layout = make_layout(select<0, 2>(shape_mnk));
617
+ auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
618
+ auto c_layout = make_layout(select<0, 1>(shape_mnk));
619
+
620
+ test_cooperative_gemm<ThreadBlockSize,
621
+ CopyMaxVecBits,
622
+ TA, TB, TC>
623
+ (a_layout,
624
+ b_layout,
625
+ c_layout,
626
+ a_layout,
627
+ b_layout,
628
+ c_layout,
629
+ tiled_mma,
630
+ ops...);
631
+ }
632
+
633
+
634
+ template<uint32_t ThreadBlockSize,
635
+ uint32_t CopyMaxVecBits,
636
+ class TA,
637
+ class TB,
638
+ class TC,
639
+ class SMemAtomLayoutA,
640
+ class SMemAtomLayoutB,
641
+ class SMemAtomLayoutC,
642
+ class ShapeMNK,
643
+ class TiledMma,
644
+ class ... Ops>
645
+ std::enable_if_t<std::conjunction_v<cute::is_layout<SMemAtomLayoutA>,
646
+ cute::is_layout<SMemAtomLayoutB>,
647
+ cute::is_layout<SMemAtomLayoutC>>>
648
+ test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a,
649
+ SMemAtomLayoutB smem_atom_layout_b,
650
+ SMemAtomLayoutC smem_atom_layout_c,
651
+ ShapeMNK shape_mnk,
652
+ TiledMma tiled_mma,
653
+ Ops&& ... ops)
654
+ {
655
+ auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk));
656
+ auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
657
+ auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk));
658
+
659
+ auto smem_a_layout = tile_to_shape(
660
+ smem_atom_layout_a,
661
+ make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
662
+
663
+ auto smem_b_layout = tile_to_shape(
664
+ smem_atom_layout_b,
665
+ make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
666
+
667
+ auto smem_c_layout = tile_to_shape(
668
+ smem_atom_layout_c,
669
+ make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout)));
670
+
671
+ test_cooperative_gemm<ThreadBlockSize,
672
+ CopyMaxVecBits,
673
+ TA, TB, TC>
674
+ (gmem_a_layout,
675
+ gmem_b_layout,
676
+ gmem_c_layout,
677
+ smem_a_layout,
678
+ smem_b_layout,
679
+ smem_c_layout,
680
+ tiled_mma,
681
+ ops...);
682
+ }
683
+
684
+
685
+ template<uint32_t ThreadBlockSize,
686
+ uint32_t CopyMaxVecBits,
687
+ class TA,
688
+ class TB,
689
+ class TC,
690
+ class ShapeMNK,
691
+ class TiledMma,
692
+ class ... Ops>
693
+ void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk,
694
+ TiledMma tiled_mma,
695
+ Ops ... ops)
696
+ {
697
+ auto a_layout = make_layout(select<0, 2>(shape_mnk));
698
+ auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
699
+ auto c_layout = make_layout(select<0, 1>(shape_mnk));
700
+
701
+
702
+ test_cooperative_gemm_rmem_c<ThreadBlockSize,
703
+ CopyMaxVecBits,
704
+ TA, TB,TC>
705
+ (a_layout,
706
+ b_layout,
707
+ c_layout,
708
+ a_layout,
709
+ b_layout,
710
+ tiled_mma,
711
+ ops...);
712
+ }
713
+
714
+ template<uint32_t ThreadBlockSize,
715
+ uint32_t CopyMaxVecBits,
716
+ class TA,
717
+ class TB,
718
+ class TC,
719
+ class SMemAtomLayoutA,
720
+ class SMemAtomLayoutB,
721
+ class ShapeMNK,
722
+ class TiledMma,
723
+ class ... Ops>
724
+ std::enable_if_t<std::conjunction_v<cute::is_layout<SMemAtomLayoutA>,
725
+ cute::is_layout<SMemAtomLayoutB>>>
726
+ test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a,
727
+ SMemAtomLayoutB smem_atom_layout_b,
728
+ ShapeMNK shape_mnk,
729
+ TiledMma tiled_mma,
730
+ Ops ... ops)
731
+ {
732
+ auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk));
733
+ auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{});
734
+ auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk));
735
+
736
+ auto smem_a_layout = tile_to_shape(
737
+ smem_atom_layout_a,
738
+ make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout)));
739
+
740
+ auto smem_b_layout = tile_to_shape(
741
+ smem_atom_layout_b,
742
+ make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout)));
743
+
744
+ test_cooperative_gemm_rmem_c<ThreadBlockSize, CopyMaxVecBits,
745
+ TA, TB, TC>
746
+ (gmem_a_layout,
747
+ gmem_b_layout,
748
+ gmem_c_layout,
749
+ smem_a_layout,
750
+ smem_b_layout,
751
+ tiled_mma,
752
+ ops...);
753
+ }
754
+
755
+ template<uint32_t ThreadBlockSize,
756
+ typename T,
757
+ class ... Args>
758
+ void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args)
759
+ {
760
+ test_cooperative_gemm_col_major_layout_rmem_c<ThreadBlockSize,
761
+ cute::sizeof_bits_v<T>,
762
+ T, T, T>
763
+ (static_cast<Args&&>(args)...);
764
+ }
765
+
766
+ template<uint32_t ThreadBlockSize,
767
+ class T,
768
+ class ... Args>
769
+ void test_cooperative_gemm_col_major_layout(Args&& ... args)
770
+ {
771
+ test_cooperative_gemm_col_major_layout<ThreadBlockSize,
772
+ cute::sizeof_bits_v<T>,
773
+ T, T, T>
774
+ (static_cast<Args&&>(args)...);
775
+ }
build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ * SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ *
30
+ **************************************************************************************************/
31
+
32
+ #pragma once
33
+
34
+ #include "cutlass_unit_test.h"
35
+
36
+ #include <iostream>
37
+ #include <cstdint>
38
+
39
+ #include <thrust/host_vector.h>
40
+ #include <thrust/device_vector.h>
41
+
42
+ #include <cute/tensor.hpp>
43
+
44
+ namespace cutlass::test {
45
+
46
// Dynamic shared-memory layout for the TMA load test kernel: the destination
// tile buffer plus a single mbarrier used to track TMA completion.
template <class ElementType, class SmemLayout>
struct SharedStorage
{
  // Destination buffer for the TMA load, sized to the smem layout's codomain.
  cute::ArrayEngine<ElementType, cute::cosize_v<SmemLayout>> smem;
  // One 64-bit shared-memory barrier, over-aligned to 16 bytes.
  alignas(16) cute::uint64_t tma_load_mbar[1];
};
52
+
53
+ #if CUDA_12_0_SM90_FEATURES_SUPPORTED
54
+
55
// Test kernel: TMA-load one CTA tile of the input tensor at a time into
// shared memory, then write each tile back out to g_out so the host can
// compare input and output element-for-element.
//
// NOTE(review): g_in is never referenced in the body — the input gmem
// address appears to be carried by the TMA descriptor inside `tma`.
template <class T, class TiledCopy, class CTA_Tiler, class GmemLayout, class SmemLayout>
__global__ void
tma_test_device_cute(T const* g_in, T* g_out,
                     CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler,
                     GmemLayout gmem_layout, SmemLayout smem_layout)
{
  using namespace cute;
  // The CTA tile must match the smem buffer extents, mode by mode.
  CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout)));

  // Use Shared Storage structure to allocate and distribute aligned SMEM addresses
  extern __shared__ char shared_memory[];
  using SharedStorage = SharedStorage<T, SmemLayout>;
  SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);

  // Construct SMEM tensor
  Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...)
  // Shared memory barriers use 64bits in SMEM for synchronization
  uint64_t* tma_load_mbar = shared_storage.tma_load_mbar;

  // TMA requires special handling of strides to deal with coord codomain mapping
  // Represent the full tensors -- get these from TMA
  Tensor mA = tma.get_tma_tensor(shape(gmem_layout));
  Tensor mB = make_tensor(make_gmem_ptr<T>(g_out), gmem_layout);

  constexpr int R = rank_v<CTA_Tiler>;
  Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)
  Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...)

  //
  // Prepare the TMA_LOAD
  //

  auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice
  Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N)
  Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N)

#if 0
  if (thread0()) {
    print(tma);
    print("TILE : "); print(cta_tiler); print("\n");
    print(" mA : "); print( mA); print("\n");
    print(" mB : "); print( mB); print("\n");
    print(" gA : "); print( gA); print("\n");
    print(" gB : "); print( gB); print("\n");
    print(" sA : "); print( sA); print("\n");
    print("tAgA_x: "); print(tAgA_x); print("\n");
    print("tAsA_x: "); print(tAsA_x); print("\n");
  }
#endif

  //
  // Perform the TMA_LOAD
  //

  // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles
  Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST)
  Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST)
  // smem holds exactly one tile at a time.
  static_assert(size<1>(tAsA) == 1);

  // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output
  Tensor tBgB = group_modes<0,R>(group_modes<R,rank(gB)>(gB)); // (CTA_TILE, REST)

#if 0
  if (thread0()) {
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
  }
#endif

  // Test L2 prefetch
  if (threadIdx.x == 0) {
    prefetch(tma, tAgA);
  }

  // Loop over the TMA stages, using smem as our buffer
  for (int stage = 0; stage < size<1>(tAgA); ++stage)
  {
    // Set the bytes transferred in this TMA transaction (may involve multiple issues)
    constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA)));

    // A single thread initializes the barrier and issues the TMA copy.
    if (threadIdx.x == 0)
    {
      /// Initialize shared memory barrier
      // Re-initialized each stage so every wait below observes phase 0.
      tma_load_mbar[0] = 0;
      cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/);
      cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes);

      copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0));
    }
    __syncthreads();

    /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value
    constexpr int kPhaseBit = 0;
    cute::wait_barrier(tma_load_mbar[0], kPhaseBit);

    //
    // Write out trivially smem -> gmem
    //

    // Subbyte elements could cause race conditions, so be even more conservative
    if (thread0()) {
      copy(sA, tBgB(_,stage));
    }

    __syncthreads();
  }
}
163
+
164
// Host driver: builds a TMA copy for the given gmem/smem layouts and CTA
// tile, launches tma_test_device_cute to round-trip data gmem -> smem ->
// gmem, and checks the output equals the input. Returns the TMA copy object
// so callers can inspect or reuse it.
//
// Buffers are allocated as raw bytes (sized via sizeof_bits), so subbyte
// element types T are supported.
template <class T, class TmaType = T, class CopyOp, class GMEM_Layout, class SMEM_Layout, class CTA_Tile>
auto
test_tma_load(CopyOp const& copy_op,
              GMEM_Layout const& gmem_layout,
              SMEM_Layout const& smem_layout,
              CTA_Tile const& cta_tile)
{
  using namespace cute;

  // Allocate and initialize host test data
  // N = number of bytes needed to back cosize(gmem_layout) elements of T.
  size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits<T>::value, 8);
  thrust::host_vector<uint8_t> h_in(N);
  for (size_t i = 0; i < h_in.size(); ++i) {
    h_in[i] = uint8_t(i % 13); // arbitrary repeating byte pattern
  }
  Tensor hA_in = make_tensor(recast_ptr<T>(h_in.data()), gmem_layout);

  // Allocate and initialize device test data
  thrust::device_vector<uint8_t> d_in = h_in;
  // Output poisoned with 0xFF so untouched bytes are detectable.
  thrust::device_vector<uint8_t> d_out(h_in.size(), uint8_t(-1)); // overflow uint

  // Create TMA for this device Tensor
  Tensor gA = make_tensor(make_gmem_ptr<T>(raw_pointer_cast(d_in.data())), gmem_layout);
  auto tma = make_tma_copy<TmaType>(copy_op, gA, smem_layout, cta_tile, Int<1>{});
  //print(tma);

  // Launch
  // Single CTA of 128 threads; dynamic smem sized for one tile + barrier.
  int smem_size = int(sizeof(SharedStorage<T, decltype(smem_layout)>));
  tma_test_device_cute<<<1, 128, smem_size>>>(
    reinterpret_cast<T const*>(raw_pointer_cast(d_in.data())),
    reinterpret_cast<T*>     (raw_pointer_cast(d_out.data())),
    tma, cta_tile,
    gmem_layout,
    smem_layout);

  // Copy results back to host
  thrust::host_vector<uint8_t> h_out = d_out;
  Tensor hA_out = make_tensor(recast_ptr<T>(h_out.data()), gmem_layout);

  // Validate the results. Print only the first 3 errors.
  int count = 3;
  for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) {
    EXPECT_EQ(hA_in(i), hA_out(i));
    if (hA_in(i) != hA_out(i)) {
      --count;
    }
  }

  return tma;
}
214
+
215
+ #endif
216
+
217
+ } // end namespace cutlass::test